Compare commits
20 Commits
master
...
enhance-bu
Author | SHA1 | Date |
---|---|---|
Brian Rosenberger | dd151aec3f | |
Brian Rosenberger | 1b3338f809 | |
Brian Rosenberger | 3d949c5348 | |
Brian Rosenberger | 6930116c18 | |
Brian Rosenberger | e27fb8422f | |
Brian Rosenberger | d0342fc939 | |
Brian Rosenberger | b34b96d929 | |
Brian Rosenberger | 8f51471a31 | |
Brian Rosenberger | dc5de40620 | |
Brian Rosenberger | e834407b6e | |
Brian Rosenberger | 4dc5a116b6 | |
Brian Rosenberger | 997143b9dd | |
Brian Rosenberger | 0bed17c97f | |
Brian Rosenberger | 8d73a7a410 | |
Brian Rosenberger | c758cf918f | |
Brian Rosenberger | 2c8c6d9624 | |
Brian Rosenberger | 0ba049885f | |
Brian Rosenberger | 345f55a003 | |
Brian Rosenberger | 1c39dbee52 | |
Brian Rosenberger | ea504bff41 |
|
@ -1,4 +1,4 @@
|
||||||
FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
|
FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
||||||
|
|
|
@ -36,6 +36,8 @@ pom.xml.versionsBackup
|
||||||
pom.xml.next
|
pom.xml.next
|
||||||
release.properties
|
release.properties
|
||||||
*dependency-reduced-pom.xml
|
*dependency-reduced-pom.xml
|
||||||
|
**/build/*
|
||||||
|
.gradle/*
|
||||||
|
|
||||||
# Specific for Nd4j
|
# Specific for Nd4j
|
||||||
*.md5
|
*.md5
|
||||||
|
@ -83,3 +85,14 @@ bruai4j-native-common/cmake*
|
||||||
/bruai4j-native/bruai4j-native-common/blasbuild/
|
/bruai4j-native/bruai4j-native-common/blasbuild/
|
||||||
/bruai4j-native/bruai4j-native-common/build/
|
/bruai4j-native/bruai4j-native-common/build/
|
||||||
/cavis-native/cavis-native-lib/blasbuild/
|
/cavis-native/cavis-native-lib/blasbuild/
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
|
||||||
|
|
|
@ -35,7 +35,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,13 +21,15 @@
|
||||||
|
|
||||||
pipeline {
|
pipeline {
|
||||||
agent {
|
agent {
|
||||||
dockerfile {
|
/* dockerfile {
|
||||||
filename 'Dockerfile'
|
filename 'Dockerfile'
|
||||||
dir '.docker'
|
dir '.docker'
|
||||||
label 'linux && cuda'
|
label 'linux && cuda'
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
//args '--gpus all' --needed for test only, you can build without GPU
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
label 'linux && cuda'
|
||||||
}
|
}
|
||||||
|
|
||||||
stages {
|
stages {
|
||||||
|
@ -43,13 +45,13 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
steps {
|
steps {
|
||||||
withGradle {
|
withGradle {
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ pipeline {
|
||||||
stages {
|
stages {
|
||||||
stage('publish-linux-cpu') {
|
stage('publish-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,5 +56,20 @@ pipeline {
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stage('test-linux-cuda') {
|
||||||
|
environment {
|
||||||
|
MAVEN = credentials('Internal_Archiva')
|
||||||
|
OSSRH = credentials('OSSRH')
|
||||||
|
}
|
||||||
|
|
||||||
|
steps {
|
||||||
|
withGradle {
|
||||||
|
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
|
||||||
|
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||||
|
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||||
|
}
|
||||||
|
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,167 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.ai.nd4j.tests;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class ExploreParamsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testParam() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(4).nOut(30).name("1. Dense").activation(Activation.TANH))
|
||||||
|
.layer(DenseLayer.builder().nIn(30).nOut(10).name("2. Dense"))
|
||||||
|
// .layer(FrozenLayer.builder(DenseLayer.builder().nOut(6).build()).build())
|
||||||
|
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(3)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
|
||||||
|
nn.init();
|
||||||
|
log.info(nn.summary());
|
||||||
|
// INDArray input = Nd4j.rand(10,4);
|
||||||
|
INDArray labels = Nd4j.zeros(9, 3);
|
||||||
|
|
||||||
|
INDArray input =
|
||||||
|
Nd4j.create(
|
||||||
|
new double[][] {
|
||||||
|
{5.15, 3.5, 1.4, 0.21}, // setosa
|
||||||
|
{4.9, 3.2, 1.4, 0.2}, // setosa
|
||||||
|
{4.7, 3.2, 1.23, 0.2}, // setosa
|
||||||
|
{7, 3.25, 4.7, 1.41}, // versicolor
|
||||||
|
{6.4, 3.2, 4.54, 1.5}, // versicolor
|
||||||
|
{6.9, 3.1, 4.92, 1.5}, // versicolor
|
||||||
|
{7.7, 3, 6.1, 2.3}, // virginica
|
||||||
|
{6.3, 3.4, 5.6, 2.45}, // virginica
|
||||||
|
{6.4, 3.12, 5.5, 1.8} // virginica
|
||||||
|
});
|
||||||
|
|
||||||
|
labels.putScalar(0, 1);
|
||||||
|
labels.putScalar(3, 1);
|
||||||
|
labels.putScalar(6, 1);
|
||||||
|
labels.putScalar(10, 1);
|
||||||
|
labels.putScalar(13, 1);
|
||||||
|
labels.putScalar(16, 1);
|
||||||
|
labels.putScalar(20, 1);
|
||||||
|
labels.putScalar(23, 1);
|
||||||
|
labels.putScalar(26, 1);
|
||||||
|
|
||||||
|
IrisDataSetIterator iter = new IrisDataSetIterator();
|
||||||
|
//Iterable<Pair<INDArray, INDArray>> it = List.of(new Pair<INDArray, INDArray>(input, labels));
|
||||||
|
List l = new ArrayList<>();
|
||||||
|
for (int i=0; i< input.rows(); i++) {
|
||||||
|
l.add(new Pair(input.getRow(i), labels.getRow(i)));
|
||||||
|
}
|
||||||
|
Iterable<Pair<INDArray, INDArray>> it = l;
|
||||||
|
INDArrayDataSetIterator diter = new INDArrayDataSetIterator(it, 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < 100; i++) {
|
||||||
|
// nn.fit(input, labels);
|
||||||
|
// nn.fit( input, labels);
|
||||||
|
nn.fit(diter);
|
||||||
|
// nn.feedForward(input);
|
||||||
|
if(i%20==0) log.info("Score: {}", nn.getScore());
|
||||||
|
}
|
||||||
|
|
||||||
|
Evaluation eval = nn.evaluate(iter, List.of("setosa", "vericolor", "virginica"));
|
||||||
|
|
||||||
|
log.info("\n{}", eval.stats());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testParam2() throws IOException {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense"))
|
||||||
|
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(10)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
|
||||||
|
nn.init();
|
||||||
|
log.info(nn.summary());
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf2 =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense").dropOut(0.7))
|
||||||
|
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(10)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn2 = new MultiLayerNetwork(conf2);
|
||||||
|
nn2.init();
|
||||||
|
log.info(nn2.summary());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
MnistDataSetIterator iter = new MnistDataSetIterator(10, 500);
|
||||||
|
MnistDataSetIterator iter2 = new MnistDataSetIterator(10, 50);
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = 0; i < 200; i++) {
|
||||||
|
nn.fit(iter);
|
||||||
|
nn2.fit(iter);
|
||||||
|
if(i%20==0) log.info("Score: {} vs. {}", nn.getScore(), nn2.getScore());
|
||||||
|
}
|
||||||
|
|
||||||
|
Evaluation eval = nn.evaluate(iter2);
|
||||||
|
Evaluation eval2 = nn2.evaluate(iter2);
|
||||||
|
|
||||||
|
log.info("\n{} \n{}", eval.stats(), eval2.stats());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,411 +1,261 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package net.brutex.gan;
|
package net.brutex.gan;
|
||||||
|
|
||||||
import java.awt.BorderLayout;
|
import static net.brutex.ai.dnn.api.NN.dense;
|
||||||
import java.awt.Dimension;
|
|
||||||
import java.awt.GridLayout;
|
import java.awt.*;
|
||||||
import java.awt.Image;
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import javax.swing.*;
|
||||||
import javax.swing.ImageIcon;
|
|
||||||
import javax.swing.JFrame;
|
|
||||||
import javax.swing.JLabel;
|
|
||||||
import javax.swing.JPanel;
|
|
||||||
import javax.swing.WindowConstants;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
|
||||||
import org.datavec.image.transform.ColorConversionTransform;
|
|
||||||
import org.datavec.image.transform.ImageTransform;
|
|
||||||
import org.datavec.image.transform.PipelineImageTransform;
|
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
|
||||||
import org.datavec.image.transform.ShowImageTransform;
|
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
|
||||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class App {
|
public class App {
|
||||||
private static final double LEARNING_RATE = 0.000002;
|
private static final double LEARNING_RATE = 0.002;
|
||||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||||
|
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||||
|
private static final int BATCHSIZE = 128;
|
||||||
|
private static JFrame frame;
|
||||||
|
private static JPanel panel;
|
||||||
|
|
||||||
private static final int X_DIM = 20 ;
|
private static LayerConfiguration[] genLayers() {
|
||||||
private static final int Y_DIM = 20;
|
return new LayerConfiguration[] {
|
||||||
private static final int CHANNELS = 1;
|
dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
private static final int batchSize = 10;
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
private static final int INPUT = 128;
|
dense().nIn(256).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private static final int OUTPUT_PER_PANEL = 4;
|
/**
|
||||||
|
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
||||||
|
*
|
||||||
|
* @return config
|
||||||
|
*/
|
||||||
|
private static NeuralNetConfiguration generator() {
|
||||||
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(genLayers())
|
||||||
|
.name("generator")
|
||||||
|
.build();
|
||||||
|
|
||||||
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
|
return conf;
|
||||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
}
|
||||||
|
|
||||||
private static JFrame frame;
|
private static LayerConfiguration[] disLayers() {
|
||||||
private static JFrame frame2;
|
return new LayerConfiguration[]{
|
||||||
private static JPanel panel;
|
dense().nIn(784).nOut(1024).build(),
|
||||||
private static JPanel panel2;
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(1024).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(512).nOut(256).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private static LayerConfiguration[] genLayers() {
|
private static NeuralNetConfiguration discriminator() {
|
||||||
return new LayerConfiguration[] {
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
.seed(42)
|
||||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
.updater(UPDATER)
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayers())
|
||||||
|
.name("discriminator")
|
||||||
|
.build();
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
return conf;
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
}
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
|
private static NeuralNetConfiguration gan() {
|
||||||
};
|
LayerConfiguration[] genLayers = genLayers();
|
||||||
}
|
LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
|
.map((layer) -> {
|
||||||
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
|
return FrozenLayerWithBackprop.builder(layer).build();
|
||||||
|
} else {
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
}).toArray(LayerConfiguration[]::new);
|
||||||
|
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
||||||
|
|
||||||
/**
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
* Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image.
|
.seed(42)
|
||||||
*
|
.updater(UPDATER)
|
||||||
* @return config
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
*/
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
private static NeuralNetConfiguration generator() {
|
.weightInit(WeightInit.XAVIER)
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
.activation(Activation.IDENTITY)
|
||||||
.seed(42)
|
.layersFromArray(layers)
|
||||||
.updater(UPDATER)
|
.name("GAN")
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.build();
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
|
||||||
//.weightInit(WeightInit.XAVIER)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.activation(Activation.IDENTITY)
|
|
||||||
.layersFromArray(genLayers())
|
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
|
||||||
.build();
|
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static LayerConfiguration[] disLayers() {
|
@Test
|
||||||
return new LayerConfiguration[]{
|
public void runTest() throws Exception {
|
||||||
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
|
App.main(null);
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
}
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
public static void main(String... args) throws Exception {
|
||||||
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
|
|
||||||
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
private static NeuralNetConfiguration discriminator() {
|
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||||
|
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||||
|
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||||
|
gen.init();
|
||||||
|
dis.init();
|
||||||
|
gan.init();
|
||||||
|
|
||||||
NeuralNetConfiguration conf =
|
copyParams(gen, dis, gan);
|
||||||
NeuralNetConfiguration.builder()
|
|
||||||
.seed(42)
|
|
||||||
.updater(UPDATER)
|
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
|
||||||
.weightNoise(null)
|
|
||||||
// .weightInitFn(new WeightInitXavier())
|
|
||||||
// .activationFn(new ActivationIdentity())
|
|
||||||
.activation(Activation.IDENTITY)
|
|
||||||
.layersFromArray(disLayers())
|
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
.build();
|
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
gen.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
}
|
dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
|
gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
|
|
||||||
private static NeuralNetConfiguration gan() {
|
trainData.reset();
|
||||||
LayerConfiguration[] genLayers = genLayers();
|
|
||||||
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
|
||||||
.map((layer) -> {
|
|
||||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
|
||||||
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
|
||||||
} else {
|
|
||||||
return layer;
|
|
||||||
}
|
|
||||||
}).toArray(LayerConfiguration[]::new);
|
|
||||||
LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers);
|
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
int j = 0;
|
||||||
.seed(42)
|
for (int i = 0; i < 50; i++) {
|
||||||
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
|
while (trainData.hasNext()) {
|
||||||
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
|
j++;
|
||||||
.gradientNormalizationThreshold( 100 )
|
|
||||||
//.weightInitFn( new WeightInitXavier() ) //this is internal
|
// generate data
|
||||||
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
|
||||||
.weightInit( WeightInit.XAVIER)
|
int batchSize = (int) real.shape()[0];
|
||||||
//.activationFn( new ActivationIdentity()) //this is internal
|
|
||||||
.activation( Activation.IDENTITY )
|
INDArray fakeIn = Nd4j.rand(batchSize, 100);
|
||||||
.layersFromArray( layers )
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
.build();
|
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||||
((NeuralNetConfiguration) conf).init();
|
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||||
return conf;
|
|
||||||
}
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
|
||||||
|
|
||||||
|
|
||||||
@Test
|
if (j % 10 == 1) {
|
||||||
public void runTest() throws Exception {
|
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
|
||||||
main();
|
INDArray[] samples = new INDArray[9];
|
||||||
}
|
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||||
|
|
||||||
public static void main(String... args) throws Exception {
|
for (int k = 0; k < 9; k++) {
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
//samples[k] = gen.output(input, false);
|
||||||
|
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
|
||||||
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
}
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(500);
|
visualize(samples);
|
||||||
|
}
|
||||||
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
}
|
||||||
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
trainData.reset();
|
||||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
|
// Copy the GANs generator to gen.
|
||||||
|
//updateGen(gen, gan);
|
||||||
|
|
||||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
|
||||||
|
|
||||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
|
||||||
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
ImageTransform tr = new PipelineImageTransform.Builder()
|
|
||||||
.addImageTransform(transform) //convert to GREY SCALE
|
|
||||||
.addImageTransform(transform3)
|
|
||||||
//.addImageTransform(transform2)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
|
|
||||||
imageRecordReader.initialize(fileSplit, tr);
|
|
||||||
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
|
|
||||||
|
|
||||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
|
||||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
|
||||||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
|
||||||
gen.init(); log.debug("Generator network: {}", gen);
|
|
||||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
|
||||||
gan.init(); log.debug("Complete GAN network: {}", gan);
|
|
||||||
|
|
||||||
|
|
||||||
copyParams(gen, dis, gan);
|
|
||||||
|
|
||||||
gen.addTrainingListeners(new PerformanceListener(15, true));
|
|
||||||
//dis.addTrainingListeners(new PerformanceListener(10, true));
|
|
||||||
//gan.addTrainingListeners(new PerformanceListener(10, true));
|
|
||||||
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
|
||||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
|
||||||
|
|
||||||
System.out.println(gan.toString());
|
|
||||||
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
|
||||||
|
|
||||||
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
|
||||||
//trainData.reset();
|
|
||||||
|
|
||||||
int j = 0;
|
|
||||||
for (int i = 0; i < 201; i++) { //epoch
|
|
||||||
while (trainData.hasNext()) {
|
|
||||||
j++;
|
|
||||||
|
|
||||||
DataSet next = trainData.next();
|
|
||||||
// generate data
|
|
||||||
INDArray real = next.getFeatures();//.div(255f);
|
|
||||||
|
|
||||||
//start next round if there are not enough images left to have a full batchsize dataset
|
|
||||||
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
|
|
||||||
log.warn("Your total number of input images is not a multiple of {}, "
|
|
||||||
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(i%20 == 0) {
|
// Copy the GANs generator to gen.
|
||||||
// frame2 = visualize(new INDArray[]{real}, batchSize,
|
updateGen(gen, gan);
|
||||||
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||||
|
if (i < genLayerCount) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
} else {
|
||||||
|
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
real.divi(255f);
|
}
|
||||||
|
|
||||||
// int batchSize = (int) real.shape()[0];
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
|
||||||
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
//log.info("real has {} items.", real.length());
|
|
||||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
|
||||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
|
||||||
|
|
||||||
|
|
||||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
|
||||||
|
|
||||||
dis.fit(data);
|
|
||||||
dis.fit(data);
|
|
||||||
|
|
||||||
// Update the discriminator in the GAN network
|
|
||||||
updateGan(gen, dis, gan);
|
|
||||||
|
|
||||||
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
|
||||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
|
|
||||||
|
|
||||||
|
|
||||||
if (j % 10 == 1) {
|
|
||||||
System.out.println("Iteration " + j + " Visualizing...");
|
|
||||||
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
|
||||||
|
|
||||||
|
|
||||||
for (int k = 0; k < samples.length; k++) {
|
|
||||||
//INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
|
||||||
INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
|
|
||||||
|
|
||||||
//samples[k] = gen.output(input, false);
|
|
||||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
|
||||||
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
//samples[k] =
|
|
||||||
samples[k].addi(1f).divi(2f).muli(255f);
|
|
||||||
|
|
||||||
}
|
|
||||||
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
trainData.reset();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the GANs generator to gen.
|
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
updateGen(gen, gan);
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||||
gen.save(new File("mnist-mlp-generator.dlj"));
|
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||||
}
|
|
||||||
|
|
||||||
private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
|
||||||
int genLayerCount = gen.getLayers().length;
|
|
||||||
for (int i = 0; i < gan.getLayers().length; i++) {
|
|
||||||
if (i < genLayerCount) {
|
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
|
||||||
} else {
|
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
|
||||||
for (int i = 0; i < gen.getLayers().length; i++) {
|
|
||||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
|
||||||
int genLayerCount = gen.getLayers().length;
|
|
||||||
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
|
||||||
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
|
||||||
if (isOrig) {
|
|
||||||
frame.setTitle("Viz Original");
|
|
||||||
} else {
|
|
||||||
frame.setTitle("Generated");
|
|
||||||
}
|
|
||||||
|
|
||||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
|
||||||
frame.setLayout(new BorderLayout());
|
|
||||||
|
|
||||||
JPanel panelx = new JPanel();
|
|
||||||
|
|
||||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
|
||||||
for (INDArray sample : samples) {
|
|
||||||
for(int i = 0; i<batchElements; i++) {
|
|
||||||
panelx.add(getImage(sample, i, isOrig));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
frame.add(panelx, BorderLayout.CENTER);
|
|
||||||
frame.setVisible(true);
|
|
||||||
|
|
||||||
frame.revalidate();
|
|
||||||
frame.setMinimumSize(new Dimension(300, 20));
|
|
||||||
frame.pack();
|
|
||||||
return frame;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
|
||||||
final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
|
|
||||||
final int imageSize = X_DIM * Y_DIM;
|
|
||||||
final int offset = batchElement * imageSize;
|
|
||||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
|
||||||
|
|
||||||
//Image in NCHW - channels first format
|
|
||||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
|
||||||
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
|
|
||||||
for (int x = 0; x < X_DIM; x++) { //step through the rows y
|
|
||||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
|
|
||||||
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
|
|
||||||
pxl++; //next item in INDArray
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ImageIcon orig = new ImageIcon(bi);
|
private static void visualize(INDArray[] samples) {
|
||||||
|
if (frame == null) {
|
||||||
|
frame = new JFrame();
|
||||||
|
frame.setTitle("Viz");
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
panel = new JPanel();
|
||||||
|
|
||||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
|
||||||
|
frame.add(panel, BorderLayout.CENTER);
|
||||||
|
frame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
return new JLabel(scaled);
|
panel.removeAll();
|
||||||
|
|
||||||
}
|
for (INDArray sample : samples) {
|
||||||
|
panel.add(getImage(sample));
|
||||||
|
}
|
||||||
|
|
||||||
|
frame.revalidate();
|
||||||
|
frame.pack();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JLabel getImage(INDArray tensor) {
|
||||||
|
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
for (int i = 0; i < 784; i++) {
|
||||||
|
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
|
||||||
|
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
|
||||||
|
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
|
||||||
|
return new JLabel(scaled);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,343 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.List;
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import javax.swing.*;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.datavec.api.split.FileSplit;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
|
import org.datavec.image.transform.*;
|
||||||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class App2 {
|
||||||
|
|
||||||
|
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
||||||
|
static final float COLORSPACE = 255f;
|
||||||
|
static final int DIMENSIONS = 28;
|
||||||
|
static final int CHANNELS = 1;
|
||||||
|
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
||||||
|
final int OUTPUT_PER_PANEL = 10;
|
||||||
|
|
||||||
|
final boolean BIAS = true;
|
||||||
|
|
||||||
|
static final int BATCHSIZE=128;
|
||||||
|
|
||||||
|
private JFrame frame2, frame;
|
||||||
|
static final String OUTPUT_DIR = "d:/out/";
|
||||||
|
|
||||||
|
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
|
||||||
|
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void runTest() throws IOException {
|
||||||
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
|
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
||||||
|
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||||
|
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||||
|
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
||||||
|
|
||||||
|
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||||
|
.addImageTransform(transform) //convert to GREY SCALE
|
||||||
|
.addImageTransform(transform3)
|
||||||
|
//.addImageTransform(transform2)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
|
||||||
|
imageRecordReader.initialize(fileSplit, tr);
|
||||||
|
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
|
||||||
|
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
|
|
||||||
|
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
|
||||||
|
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
|
||||||
|
|
||||||
|
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
|
.map((layer) -> {
|
||||||
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
|
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||||
|
} else {
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
}).toArray(LayerConfiguration[]::new);
|
||||||
|
|
||||||
|
NeuralNetConfiguration netConfiguration =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("GAN")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.updater(App2Config.UPDATER)
|
||||||
|
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
|
||||||
|
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
|
||||||
|
|
||||||
|
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||||
|
gen.init(); log.debug("Generator network: {}", gen);
|
||||||
|
gan.init(); log.debug("GAN network: {}", gan);
|
||||||
|
|
||||||
|
|
||||||
|
log.info("Generator Summary:\n{}", gen.summary());
|
||||||
|
log.info("GAN Summary:\n{}", gan.summary());
|
||||||
|
dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||||
|
gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
|
||||||
|
gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
|
||||||
|
|
||||||
|
int j = 0;
|
||||||
|
for (int i = 0; i < 51; i++) { //epoch
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
j++;
|
||||||
|
DataSet next = trainData.next();
|
||||||
|
// generate data
|
||||||
|
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
|
||||||
|
|
||||||
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
//if(i%20 == 0) {
|
||||||
|
|
||||||
|
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
|
||||||
|
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||||
|
//}
|
||||||
|
//real.divi(255f);
|
||||||
|
|
||||||
|
// int batchSize = (int) real.shape()[0];
|
||||||
|
|
||||||
|
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||||
|
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||||
|
|
||||||
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
|
// when generator has TANH as activation - value range is -1 to 1
|
||||||
|
// when generator has SIGMOID, then range is 0 to 1
|
||||||
|
fake.addi(1f).divi(2f);
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
|
||||||
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
|
||||||
|
|
||||||
|
//Visualize and reporting
|
||||||
|
if (j % 10 == 1) {
|
||||||
|
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||||
|
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
|
||||||
|
|
||||||
|
|
||||||
|
for (int k = 0; k < samples.length; k++) {
|
||||||
|
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
|
||||||
|
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
|
||||||
|
input = input.reshape(1, App2Config.INPUT);
|
||||||
|
|
||||||
|
//samples[k] = gen.output(input, false);
|
||||||
|
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//samples[k] =
|
||||||
|
//samples[k].muli(255f);
|
||||||
|
|
||||||
|
}
|
||||||
|
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
// Copy the GANs generator to gen.
|
||||||
|
updateGen(gen, gan);
|
||||||
|
log.info("Updated GAN's generator from gen.");
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
||||||
|
if (isOrig) {
|
||||||
|
frame.setTitle("Viz Original");
|
||||||
|
} else {
|
||||||
|
frame.setTitle("Generated");
|
||||||
|
}
|
||||||
|
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
|
JPanel panelx = new JPanel();
|
||||||
|
|
||||||
|
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
for(int i = 0; i<batchElements; i++) {
|
||||||
|
panelx.add(getImage(sample, i, isOrig));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
frame.add(panelx, BorderLayout.CENTER);
|
||||||
|
frame.setVisible(true);
|
||||||
|
|
||||||
|
frame.revalidate();
|
||||||
|
frame.setMinimumSize(new Dimension(300, 20));
|
||||||
|
frame.pack();
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||||
|
final BufferedImage bi;
|
||||||
|
if(CHANNELS >1) {
|
||||||
|
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||||
|
} else {
|
||||||
|
bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||||
|
}
|
||||||
|
final int imageSize = DIMENSIONS * DIMENSIONS;
|
||||||
|
final int offset = batchElement * imageSize;
|
||||||
|
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||||
|
|
||||||
|
//Image in NCHW - channels first format
|
||||||
|
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||||
|
for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
|
||||||
|
for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
|
||||||
|
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
||||||
|
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
|
||||||
|
bi.getRaster().setSample(x, y, c, f_pxl);
|
||||||
|
pxl++; //next item in INDArray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||||
|
return new JLabel(scaled);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||||
|
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the images to disk
|
||||||
|
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||||
|
|
||||||
|
log.debug("Images saved successfully.");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("Error saving the images: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||||
|
File directory = new File(outputDirectory);
|
||||||
|
if (!directory.exists()) {
|
||||||
|
directory.mkdir();
|
||||||
|
}
|
||||||
|
|
||||||
|
File outputFile = new File(directory, fileName);
|
||||||
|
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage imageToBufferedImage(Image image) {
|
||||||
|
if (image instanceof BufferedImage) {
|
||||||
|
return (BufferedImage) image;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a buffered image with the same dimensions and transparency as the original image
|
||||||
|
BufferedImage bufferedImage;
|
||||||
|
if (CHANNELS > 1) {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||||
|
} else {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw the original image onto the buffered image
|
||||||
|
Graphics2D g2d = bufferedImage.createGraphics();
|
||||||
|
g2d.drawImage(image, 0, 0, null);
|
||||||
|
g2d.dispose();
|
||||||
|
|
||||||
|
return bufferedImage;
|
||||||
|
}
|
||||||
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||||
|
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,176 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import static net.brutex.ai.dnn.api.NN.*;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
public class App2Config {
|
||||||
|
|
||||||
|
public static final int INPUT = 100;
|
||||||
|
public static final int X_DIM = 28;
|
||||||
|
public static final int y_DIM = 28;
|
||||||
|
public static final int CHANNELS = 1;
|
||||||
|
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
||||||
|
|
||||||
|
static LayerConfiguration[] genLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {
|
||||||
|
/*
|
||||||
|
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(), /*
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(1,1)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().nOut(CHANNELS).build(),
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(2,2)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
|
||||||
|
|
||||||
|
|
||||||
|
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
// DropoutLayer.builder(0.001).build(),
|
||||||
|
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
|
||||||
|
|
||||||
|
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(256).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LayerConfiguration[] disLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {/*
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
|
||||||
|
BatchNormalization.builder().build(),
|
||||||
|
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build()
|
||||||
|
|
||||||
|
|
||||||
|
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build() */
|
||||||
|
dense().nIn(784).nOut(1024).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(1024).nOut(512).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(512).nOut(256).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static NeuralNetConfiguration generator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("generator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(App2Config.genLayerConfig())
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NeuralNetConfiguration discriminator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("discriminator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayerConfig())
|
||||||
|
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,12 +24,14 @@ package net.brutex.gan;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -98,7 +100,10 @@ public class MnistSimpleGAN {
|
||||||
|
|
||||||
return new MultiLayerNetwork(discConf);
|
return new MultiLayerNetwork(discConf);
|
||||||
}
|
}
|
||||||
|
@Test
|
||||||
|
public void runTest() throws Exception {
|
||||||
|
main(null);
|
||||||
|
}
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
GAN gan = new GAN.Builder()
|
GAN gan = new GAN.Builder()
|
||||||
.generator(MnistSimpleGAN::getGenerator)
|
.generator(MnistSimpleGAN::getGenerator)
|
||||||
|
@ -108,6 +113,7 @@ public class MnistSimpleGAN {
|
||||||
.updater(UPDATER)
|
.updater(UPDATER)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold(100)
|
.gradientNormalizationThreshold(100)
|
||||||
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
|
@ -88,7 +88,7 @@ public class CNN1DTestCases {
|
||||||
.convolutionMode(ConvolutionMode.Same))
|
.convolutionMode(ConvolutionMode.Same))
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", Convolution1DLayer.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
.layer("0", Convolution1D.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
||||||
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
||||||
.layer("2", Cropping1D.builder(1).build(), "1")
|
.layer("2", Cropping1D.builder(1).build(), "1")
|
||||||
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
# Default logging detail level for all instances of SimpleLogger.
|
# Default logging detail level for all instances of SimpleLogger.
|
||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
# If not specified, defaults to "info".
|
# If not specified, defaults to "info".
|
||||||
org.slf4j.simpleLogger.defaultLogLevel=trace
|
org.slf4j.simpleLogger.defaultLogLevel=debug
|
||||||
|
|
||||||
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
|
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
|
||||||
# If the format is not specified or is invalid, the default format is used.
|
# If the format is not specified or is invalid, the default format is used.
|
||||||
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
||||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
||||||
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
||||||
|
|
||||||
# Set to true if you want to output the current thread name.
|
# Set to true if you want to output the current thread name.
|
||||||
# Defaults to true.
|
# Defaults to true.
|
||||||
org.slf4j.simpleLogger.showThreadName=true
|
#org.slf4j.simpleLogger.showThreadName=true
|
|
@ -71,7 +71,7 @@ dependencies {
|
||||||
// api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
|
// api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
|
||||||
|
|
||||||
|
|
||||||
api "org.projectlombok:lombok:1.18.26"
|
api "org.projectlombok:lombok:1.18.28"
|
||||||
|
|
||||||
/*Logging*/
|
/*Logging*/
|
||||||
api 'org.slf4j:slf4j-api:2.0.3'
|
api 'org.slf4j:slf4j-api:2.0.3'
|
||||||
|
|
|
@ -2385,11 +2385,15 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
long[] stride();
|
long[] stride();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray
|
* Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray <br/><br/>
|
||||||
* @return the ordering of this ndarray
|
* C Is Contiguous layout. Mathematically speaking, row major.<br/>
|
||||||
*/
|
* F Is Fortran contiguous layout. Mathematically speaking, column major.<br/>
|
||||||
char ordering();
|
* {@see https://en.wikipedia.org/wiki/Row-_and_column-major_order}<br/>
|
||||||
|
*
|
||||||
|
* @return the ordering of this ndarray
|
||||||
|
*/
|
||||||
|
char ordering();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the size along a specified dimension
|
* Returns the size along a specified dimension
|
||||||
|
|
|
@ -334,6 +334,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
|
||||||
public void save(File to) {
|
public void save(File to) {
|
||||||
try (FileOutputStream fos = new FileOutputStream(to, false);
|
try (FileOutputStream fos = new FileOutputStream(to, false);
|
||||||
BufferedOutputStream bos = new BufferedOutputStream(fos)) {
|
BufferedOutputStream bos = new BufferedOutputStream(fos)) {
|
||||||
|
to.mkdirs();
|
||||||
save(bos);
|
save(bos);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
|
|
@ -5121,7 +5121,7 @@ public class Nd4j {
|
||||||
Nd4j.backend = backend;
|
Nd4j.backend = backend;
|
||||||
updateNd4jContext();
|
updateNd4jContext();
|
||||||
props = Nd4jContext.getInstance().getConf();
|
props = Nd4jContext.getInstance().getConf();
|
||||||
logger.info("Properties for Nd4jContext " + props);
|
log.debug("Properties for Nd4jContext {}", props);
|
||||||
PropertyParser pp = new PropertyParser(props);
|
PropertyParser pp = new PropertyParser(props);
|
||||||
|
|
||||||
String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
|
String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
|
||||||
|
|
|
@ -166,10 +166,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
int seed = 123;
|
int seed = 123;
|
||||||
int listenerFreq = 1;
|
int listenerFreq = 1;
|
||||||
|
|
||||||
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
|
final LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
|
||||||
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
|
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
|
final var builder = NeuralNetConfiguration.builder().seed(seed)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
|
.layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
|
||||||
|
@ -181,7 +181,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
|
.inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
|
||||||
|
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
|
final MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
|
||||||
model.init();
|
model.init();
|
||||||
|
|
||||||
model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
|
model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
|
||||||
|
|
|
@ -45,6 +45,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||||
|
@ -924,8 +925,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(EpochTerminationCondition e : etc ){
|
for(EpochTerminationCondition e : etc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(e);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(e);
|
||||||
EpochTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, EpochTerminationCondition.class);
|
EpochTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, EpochTerminationCondition.class);
|
||||||
assertEquals(e, c);
|
assertEquals(e, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -936,8 +937,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(IterationTerminationCondition i : itc ){
|
for(IterationTerminationCondition i : itc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(i);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(i);
|
||||||
IterationTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, IterationTerminationCondition.class);
|
IterationTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, IterationTerminationCondition.class);
|
||||||
assertEquals(i, c);
|
assertEquals(i, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -309,7 +309,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
|
||||||
.list()
|
|
||||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
|
.layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, OutputLayer.builder().nOut(10).build())
|
.layer(1, OutputLayer.builder().nOut(10).build())
|
||||||
|
|
|
@ -77,7 +77,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder().updater(new NoOp())
|
NeuralNetConfiguration.builder().updater(new NoOp())
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.seed(12345L)
|
.seed(12345L)
|
||||||
.dist(new NormalDistribution(0, 1)).list()
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
||||||
.activation(Activation.IDENTITY).build())
|
.activation(Activation.IDENTITY).build())
|
||||||
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
||||||
|
@ -122,7 +122,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).seed(12345L)
|
.updater(new NoOp()).seed(12345L)
|
||||||
.dist(new NormalDistribution(0, 2)).list()
|
.dist(new NormalDistribution(0, 2)).list()
|
||||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
|
.layer(0, Convolution2D.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
|
||||||
.activation(Activation.IDENTITY).build())
|
.activation(Activation.IDENTITY).build())
|
||||||
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).build())
|
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).build())
|
||||||
.layer(2, ActivationLayer.builder().activation(Activation.TANH).build())
|
.layer(2, ActivationLayer.builder().activation(Activation.TANH).build())
|
||||||
|
|
|
@ -91,9 +91,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -202,7 +201,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -211,7 +210,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(Cropping1D.builder(cropping).build())
|
.layer(Cropping1D.builder(cropping).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -317,7 +316,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -326,7 +325,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -435,10 +434,9 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
|
||||||
.layer(
|
.layer(
|
||||||
0,
|
0,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -447,7 +445,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
1,
|
1,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -461,6 +459,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
.padding(padding)
|
.padding(padding)
|
||||||
.pnorm(pnorm)
|
.pnorm(pnorm)
|
||||||
|
.name("SubsamplingLayer")
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
3,
|
3,
|
||||||
|
@ -548,7 +547,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -562,7 +561,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.pnorm(pnorm)
|
.pnorm(pnorm)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -655,7 +654,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
|
@ -664,7 +663,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.nOut(convNOut1)
|
.nOut(convNOut1)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.convolutionMode(ConvolutionMode.Causal)
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
|
|
@ -0,0 +1,811 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.gradientcheck;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.util.Convolution1DUtils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class CNN1DNewGradientCheckTest extends BaseDL4JTest {
|
||||||
|
private static final boolean PRINT_RESULTS = true;
|
||||||
|
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
||||||
|
private static final double DEFAULT_EPS = 1e-6;
|
||||||
|
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
||||||
|
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
|
static {
|
||||||
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1D() {
|
||||||
|
int minibatchSize = 4;
|
||||||
|
int[] dataChannels = {4, 10}; //the input
|
||||||
|
int[] kernels = {2,4,5,8};
|
||||||
|
int stride = 2;
|
||||||
|
int padding = 3;
|
||||||
|
int seriesLength = 300;
|
||||||
|
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
for (int dChannels : dataChannels) {
|
||||||
|
int numLabels = ((seriesLength + (2 * padding) - kernel) / stride) + 1;
|
||||||
|
final NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(dChannels) // channels
|
||||||
|
.nOut(3)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(4)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(dChannels, seriesLength))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, dChannels, seriesLength);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, 4, numLabels);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < numLabels; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % 4, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
String msg =
|
||||||
|
"Minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ Activation.RELU
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
System.out.println(msg);
|
||||||
|
for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
|
/**
|
||||||
|
List<Pair<INDArray, INDArray>> iter = new java.util.ArrayList<>(Collections.emptyList());
|
||||||
|
iter.add(new Pair<>(input, labels));
|
||||||
|
for(int x=0;x<100; x++) net.fit(input, labels);
|
||||||
|
Evaluation eval = net.evaluate(new INDArrayDataSetIterator(iter,2), Arrays.asList(new String[]{"One", "Two", "Three", "Four"}));
|
||||||
|
// net.fit(input, labels);
|
||||||
|
eval.eval(labels, net.output(input));
|
||||||
|
|
||||||
|
**/
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithLocallyConnected1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {2, 3};
|
||||||
|
int length = 25;
|
||||||
|
int convNIn = 18;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1,2,4};
|
||||||
|
int stride = 1;
|
||||||
|
int padding = 0;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < length; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(convNIn)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
LocallyConnected1D.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(convNOut1)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.hasBias(false)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithCropping1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
|
||||||
|
int padding = 0;
|
||||||
|
int cropping = 1;
|
||||||
|
int croppedLength = length - 2 * cropping;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < croppedLength; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(Cropping1D.builder(cropping).build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithZeroPadding1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
int padding = 0;
|
||||||
|
int zeroPadding = 2;
|
||||||
|
int paddedLength = length + 2 * zeroPadding;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < paddedLength; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(2, kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(ZeroPadding1DLayer.builder(0).build())
|
||||||
|
.layer(
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithSubsampling1D() {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
int padding = 0;
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < length; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
0,
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
1,
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
2,
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.name("SubsamplingLayer")
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
3,
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1dWithMasking() {
|
||||||
|
int length = 12;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 3;
|
||||||
|
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG
|
||||||
|
};
|
||||||
|
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (ConvolutionMode cm :
|
||||||
|
new ConvolutionMode[] {ConvolutionMode.Same, ConvolutionMode.Truncate}) {
|
||||||
|
for (int stride : new int[] {1, 2}) {
|
||||||
|
String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
|
||||||
|
log.info("Starting test: " + s);
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(cm)
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(2)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.stride(stride)
|
||||||
|
.nIn(convNIn)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(2)
|
||||||
|
.stride(stride)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(2)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.stride(stride)
|
||||||
|
.nIn(convNOut1)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray f = Nd4j.rand(2, convNIn, length);
|
||||||
|
INDArray fm = Nd4j.create(2, length);
|
||||||
|
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
|
||||||
|
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
|
||||||
|
|
||||||
|
INDArray label = TestUtils.randomOneHot(2, finalNOut);
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
|
||||||
|
|
||||||
|
assertTrue(gradOK, s);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
||||||
|
// TODO also check that masked step values don't impact forward pass, score or gradients
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(f, label, fm, null);
|
||||||
|
double scoreBefore = net.score(ds);
|
||||||
|
net.setInput(f);
|
||||||
|
net.setLabels(label);
|
||||||
|
net.setLayerMaskArrays(fm, null);
|
||||||
|
net.computeGradientAndScore();
|
||||||
|
INDArray gradBefore = net.getFlattenedGradients().dup();
|
||||||
|
f.putScalar(1, 0, 10, 10.0);
|
||||||
|
f.putScalar(1, 1, 11, 20.0);
|
||||||
|
double scoreAfter = net.score(ds);
|
||||||
|
net.setInput(f);
|
||||||
|
net.setLabels(label);
|
||||||
|
net.setLayerMaskArrays(fm, null);
|
||||||
|
net.computeGradientAndScore();
|
||||||
|
INDArray gradAfter = net.getFlattenedGradients().dup();
|
||||||
|
|
||||||
|
assertEquals(scoreBefore, scoreAfter, 1e-6);
|
||||||
|
assertEquals(gradBefore, gradAfter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1Causal() throws Exception {
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 3;
|
||||||
|
|
||||||
|
int[] lengths = {11, 12, 13, 9, 10, 11};
|
||||||
|
int[] kernels = {2, 3, 2, 4, 2, 3};
|
||||||
|
int[] dilations = {1, 1, 2, 1, 2, 1};
|
||||||
|
int[] strides = {1, 2, 1, 2, 1, 1};
|
||||||
|
boolean[] masks = {false, true, false, true, false, true};
|
||||||
|
boolean[] hasB = {true, false, true, false, true, true};
|
||||||
|
for (int i = 0; i < lengths.length; i++) {
|
||||||
|
int length = lengths[i];
|
||||||
|
int k = kernels[i];
|
||||||
|
int d = dilations[i];
|
||||||
|
int st = strides[i];
|
||||||
|
boolean mask = masks[i];
|
||||||
|
boolean hasBias = hasB[i];
|
||||||
|
// TODO has bias
|
||||||
|
String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
|
||||||
|
log.info("Starting test: " + s);
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(k)
|
||||||
|
.dilation(d)
|
||||||
|
.hasBias(hasBias)
|
||||||
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
.stride(st)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(k)
|
||||||
|
.dilation(d)
|
||||||
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
.stride(st)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
|
||||||
|
INDArray fm = null;
|
||||||
|
if (mask) {
|
||||||
|
fm = Nd4j.create(2, length);
|
||||||
|
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
|
||||||
|
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
|
||||||
|
long outSize2 =
|
||||||
|
Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
|
||||||
|
|
||||||
|
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"Minibatch="
|
||||||
|
+ 1
|
||||||
|
+ ", activationFn="
|
||||||
|
+ Activation.RELU
|
||||||
|
+ ", kernel = "
|
||||||
|
+ k;
|
||||||
|
|
||||||
|
System.out.println(msg);
|
||||||
|
for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
|
||||||
|
|
||||||
|
assertTrue(gradOK, s);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -112,9 +112,8 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.list()
|
|
||||||
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
||||||
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
||||||
.convolutionMode(mode).dataFormat(df)
|
.convolutionMode(mode).dataFormat(df)
|
||||||
|
@ -400,7 +399,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
|
||||||
.layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
|
.layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
|
||||||
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
||||||
.convolutionMode(mode).dataFormat(df)
|
.convolutionMode(mode).dataFormat(df)
|
||||||
|
|
|
@ -108,8 +108,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.seed(12345L)
|
.seed(12345L)
|
||||||
.list()
|
|
||||||
.layer(0, ConvolutionLayer.builder(1, 1).nOut(6).activation(afn).build())
|
.layer(0, Convolution2D.builder().kernelSize(1).stride(1).nOut(6).activation(afn).build())
|
||||||
.layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
|
.layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
|
||||||
.inputType(InputType.convolutionalFlat(1, 4, 1));
|
.inputType(InputType.convolutionalFlat(1, 4, 1));
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -336,7 +337,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
// to ensure that we carry the parameters through
|
// to ensure that we carry the parameters through
|
||||||
// the serializer.
|
// the serializer.
|
||||||
try{
|
try{
|
||||||
ObjectMapper m = NeuralNetConfiguration.mapper();
|
ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
String s = m.writeValueAsString(lossFunctions[i]);
|
String s = m.writeValueAsString(lossFunctions[i]);
|
||||||
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
||||||
lossFunctions[i] = lf2;
|
lossFunctions[i] = lf2;
|
||||||
|
|
|
@ -180,7 +180,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
Pooling2D.class, //Alias for SubsamplingLayer
|
Pooling2D.class, //Alias for SubsamplingLayer
|
||||||
Convolution2D.class, //Alias for ConvolutionLayer
|
Convolution2D.class, //Alias for ConvolutionLayer
|
||||||
Pooling1D.class, //Alias for Subsampling1D
|
Pooling1D.class, //Alias for Subsampling1D
|
||||||
Convolution1D.class, //Alias for Convolution1DLayer
|
Convolution1D.class, //Alias for Convolution1D
|
||||||
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.Timeout;
|
import org.junit.jupiter.api.Timeout;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -1026,7 +1026,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
} catch (DL4JInvalidInputException e) {
|
} catch (DL4JInvalidInputException e) {
|
||||||
// e.printStackTrace();
|
// e.printStackTrace();
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
assertTrue(msg.contains(Convolution2DUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -921,7 +921,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder()
|
NeuralNetConfiguration.builder()
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
|
@ -975,7 +975,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1dCausalAllowed() {
|
public void testConv1dCausalAllowed() {
|
||||||
Convolution1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Convolution1D.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -346,7 +346,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
int[] outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -367,7 +367,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(3, it.getHeight());
|
assertEquals(3, it.getHeight());
|
||||||
assertEquals(3, it.getWidth());
|
assertEquals(3, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(3, outSize[0]);
|
assertEquals(3, outSize[0]);
|
||||||
assertEquals(3, outSize[1]);
|
assertEquals(3, outSize[1]);
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
fail("Exception expected");
|
fail("Exception expected");
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
|
@ -409,7 +409,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(1, it.getHeight());
|
assertEquals(1, it.getHeight());
|
||||||
assertEquals(1, it.getWidth());
|
assertEquals(1, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(1, outSize[0]);
|
assertEquals(1, outSize[0]);
|
||||||
assertEquals(1, outSize[1]);
|
assertEquals(1, outSize[1]);
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -732,7 +732,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
||||||
Convolution1DLayer.builder().kernelSize(3).stride(1).nOut(3).build())
|
Convolution1D.builder().kernelSize(3).stride(1).nOut(3).build())
|
||||||
.layer(BatchNormalization.builder().build())
|
.layer(BatchNormalization.builder().build())
|
||||||
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
||||||
.inputType(InputType.recurrent(3))
|
.inputType(InputType.recurrent(3))
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(output)
|
.setOutputs(output)
|
||||||
.layer(conv, Convolution1DLayer.builder(7)
|
.layer(conv, Convolution1D.builder(7)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.nOut(input.size(1))
|
.nOut(input.size(1))
|
||||||
.weightInit(new WeightInitIdentity())
|
.weightInit(new WeightInitIdentity())
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.regressiontest;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
||||||
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
||||||
|
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
for (Distribution d : distributions) {
|
for (Distribution d : distributions) {
|
||||||
String json = om.writeValueAsString(d);
|
String json = om.writeValueAsString(d);
|
||||||
|
@ -50,7 +51,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
||||||
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
||||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
||||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.jita.allocator.Allocator;
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
|
@ -681,9 +681,9 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
|
|
||||||
int[] outSize;
|
int[] outSize;
|
||||||
if (convolutionMode == ConvolutionMode.Same) {
|
if (convolutionMode == ConvolutionMode.Same) {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
||||||
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
padding = Convolution2DUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
int[] padBottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
if(!Arrays.equals(padding, padBottomRight)){
|
if(!Arrays.equals(padding, padBottomRight)){
|
||||||
/*
|
/*
|
||||||
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
||||||
|
@ -731,7 +731,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
// CuDNN handle
|
// CuDNN handle
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
||||||
|
|
|
@ -42,7 +42,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.common.primitives.Counter;
|
import org.nd4j.common.primitives.Counter;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -442,8 +442,8 @@ public class KerasModel {
|
||||||
KerasInput kerasInput = (KerasInput) layer;
|
KerasInput kerasInput = (KerasInput) layer;
|
||||||
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
||||||
//no dim order, try to pull it from the next layer if there is one
|
//no dim order, try to pull it from the next layer if there is one
|
||||||
if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
|
if(Convolution2DUtils.layerHasConvolutionLayout(layer1)) {
|
||||||
CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
|
CNN2DFormat formatForLayer = Convolution2DUtils.getFormatForLayer(layer1);
|
||||||
if(formatForLayer == CNN2DFormat.NCHW) {
|
if(formatForLayer == CNN2DFormat.NCHW) {
|
||||||
dimOrder = KerasLayer.DimOrder.THEANO;
|
dimOrder = KerasLayer.DimOrder.THEANO;
|
||||||
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
||||||
|
|
|
@ -43,223 +43,255 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils.impo
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class KerasSequentialModel extends KerasModel {
|
public class KerasSequentialModel extends KerasModel {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* (Recommended) Builder-pattern constructor for Sequential model.
|
* (Recommended) Builder-pattern constructor for Sequential model.
|
||||||
*
|
*
|
||||||
* @param modelBuilder builder object
|
* @param modelBuilder builder object
|
||||||
* @throws IOException I/O exception
|
* @throws IOException I/O exception
|
||||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
||||||
*/
|
*/
|
||||||
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
||||||
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
|
throws UnsupportedKerasConfigurationException,
|
||||||
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
|
IOException,
|
||||||
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
|
InvalidKerasConfigurationException {
|
||||||
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
|
this(
|
||||||
|
modelBuilder.getModelJson(),
|
||||||
|
modelBuilder.getModelYaml(),
|
||||||
|
modelBuilder.getWeightsArchive(),
|
||||||
|
modelBuilder.getWeightsRoot(),
|
||||||
|
modelBuilder.getTrainingJson(),
|
||||||
|
modelBuilder.getTrainingArchive(),
|
||||||
|
modelBuilder.isEnforceTrainingConfig(),
|
||||||
|
modelBuilder.getInputShape());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
|
||||||
|
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
|
||||||
|
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
|
||||||
|
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
|
||||||
|
* otherwise ignored.
|
||||||
|
*
|
||||||
|
* @param modelJson model configuration JSON string
|
||||||
|
* @param modelYaml model configuration YAML string
|
||||||
|
* @param trainingJson training configuration JSON string
|
||||||
|
* @throws IOException I/O exception
|
||||||
|
*/
|
||||||
|
public KerasSequentialModel(
|
||||||
|
String modelJson,
|
||||||
|
String modelYaml,
|
||||||
|
Hdf5Archive weightsArchive,
|
||||||
|
String weightsRoot,
|
||||||
|
String trainingJson,
|
||||||
|
Hdf5Archive trainingArchive,
|
||||||
|
boolean enforceTrainingConfig,
|
||||||
|
int[] inputShape)
|
||||||
|
throws IOException,
|
||||||
|
InvalidKerasConfigurationException,
|
||||||
|
UnsupportedKerasConfigurationException {
|
||||||
|
|
||||||
|
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
||||||
|
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
||||||
|
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
||||||
|
this.enforceTrainingConfig = enforceTrainingConfig;
|
||||||
|
|
||||||
|
/* Determine model configuration type. */
|
||||||
|
if (!modelConfig.containsKey(config.getFieldClassName()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Could not determine Keras model class (no "
|
||||||
|
+ config.getFieldClassName()
|
||||||
|
+ " field found)");
|
||||||
|
this.className = (String) modelConfig.get(config.getFieldClassName());
|
||||||
|
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Model class name must be "
|
||||||
|
+ config.getFieldClassNameSequential()
|
||||||
|
+ " (found "
|
||||||
|
+ this.className
|
||||||
|
+ ")");
|
||||||
|
|
||||||
|
/* Process layer configurations. */
|
||||||
|
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Could not find layer configurations (no "
|
||||||
|
+ config.getModelFieldConfig()
|
||||||
|
+ " field found)");
|
||||||
|
|
||||||
|
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
|
||||||
|
// For consistency
|
||||||
|
// "config" is now an object containing a "name" and "layers", the latter contain the same data
|
||||||
|
// as before.
|
||||||
|
// This change only affects Sequential models.
|
||||||
|
List<Object> layerList;
|
||||||
|
try {
|
||||||
|
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
||||||
|
} catch (Exception e) {
|
||||||
|
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
||||||
|
layerList = (List<Object>) layerMap.get("layers");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
|
||||||
* (Not recommended) Constructor for Sequential model from model configuration
|
this.layers = layerPair.getFirst();
|
||||||
* (JSON or YAML), training configuration (JSON), weights, and "training mode"
|
this.layersOrdered = layerPair.getSecond();
|
||||||
* boolean indicator. When built in training mode, certain unsupported configurations
|
|
||||||
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
|
|
||||||
* will generate warnings but will be otherwise ignored.
|
|
||||||
*
|
|
||||||
* @param modelJson model configuration JSON string
|
|
||||||
* @param modelYaml model configuration YAML string
|
|
||||||
* @param trainingJson training configuration JSON string
|
|
||||||
* @throws IOException I/O exception
|
|
||||||
*/
|
|
||||||
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
|
|
||||||
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
|
|
||||||
int[] inputShape)
|
|
||||||
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
|
||||||
|
|
||||||
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
KerasLayer inputLayer;
|
||||||
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
||||||
this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
|
inputLayer = this.layersOrdered.get(0);
|
||||||
this.enforceTrainingConfig = enforceTrainingConfig;
|
} else {
|
||||||
|
/* Add placeholder input layer and update lists of input and output layers. */
|
||||||
|
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
||||||
|
Preconditions.checkState(
|
||||||
|
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
|
||||||
|
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
||||||
|
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
||||||
|
this.layers.put(inputLayer.getName(), inputLayer);
|
||||||
|
this.layersOrdered.add(0, inputLayer);
|
||||||
|
}
|
||||||
|
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
||||||
|
this.outputLayerNames =
|
||||||
|
new ArrayList<>(
|
||||||
|
Collections.singletonList(
|
||||||
|
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
||||||
|
|
||||||
/* Determine model configuration type. */
|
/* Update each layer's inbound layer list to include (only) previous layer. */
|
||||||
if (!modelConfig.containsKey(config.getFieldClassName()))
|
KerasLayer prevLayer = null;
|
||||||
throw new InvalidKerasConfigurationException(
|
for (KerasLayer layer : this.layersOrdered) {
|
||||||
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
|
if (prevLayer != null)
|
||||||
this.className = (String) modelConfig.get(config.getFieldClassName());
|
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
||||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
prevLayer = layer;
|
||||||
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
|
|
||||||
+ " (found " + this.className + ")");
|
|
||||||
|
|
||||||
/* Process layer configurations. */
|
|
||||||
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
|
|
||||||
|
|
||||||
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
|
|
||||||
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
|
|
||||||
// This change only affects Sequential models.
|
|
||||||
List<Object> layerList;
|
|
||||||
try {
|
|
||||||
layerList = (List<Object>) modelConfig.get(config.getModelFieldConfig());
|
|
||||||
} catch (Exception e) {
|
|
||||||
HashMap layerMap = (HashMap<String, Object>) modelConfig.get(config.getModelFieldConfig());
|
|
||||||
layerList = (List<Object>) layerMap.get("layers");
|
|
||||||
}
|
|
||||||
|
|
||||||
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
|
|
||||||
prepareLayers(layerList);
|
|
||||||
this.layers = layerPair.getFirst();
|
|
||||||
this.layersOrdered = layerPair.getSecond();
|
|
||||||
|
|
||||||
KerasLayer inputLayer;
|
|
||||||
if (this.layersOrdered.get(0) instanceof KerasInput) {
|
|
||||||
inputLayer = this.layersOrdered.get(0);
|
|
||||||
} else {
|
|
||||||
/* Add placeholder input layer and update lists of input and output layers. */
|
|
||||||
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
|
||||||
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
|
|
||||||
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
|
||||||
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
|
||||||
this.layers.put(inputLayer.getName(), inputLayer);
|
|
||||||
this.layersOrdered.add(0, inputLayer);
|
|
||||||
}
|
|
||||||
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
|
||||||
this.outputLayerNames = new ArrayList<>(
|
|
||||||
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
|
||||||
|
|
||||||
/* Update each layer's inbound layer list to include (only) previous layer. */
|
|
||||||
KerasLayer prevLayer = null;
|
|
||||||
for (KerasLayer layer : this.layersOrdered) {
|
|
||||||
if (prevLayer != null)
|
|
||||||
layer.setInboundLayerNames(Collections.singletonList(prevLayer.getName()));
|
|
||||||
prevLayer = layer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Import training configuration. */
|
|
||||||
if (enforceTrainingConfig) {
|
|
||||||
if (trainingJson != null)
|
|
||||||
importTrainingConfiguration(trainingJson);
|
|
||||||
else log.warn("If enforceTrainingConfig is true, a training " +
|
|
||||||
"configuration object has to be provided. Usually the only practical way to do this is to store" +
|
|
||||||
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
|
|
||||||
" separately no training configuration is attached.");
|
|
||||||
}
|
|
||||||
|
|
||||||
this.outputTypes = inferOutputTypes(inputShape);
|
|
||||||
|
|
||||||
if (weightsArchive != null)
|
|
||||||
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/* Import training configuration. */
|
||||||
* Default constructor
|
if (enforceTrainingConfig) {
|
||||||
*/
|
if (trainingJson != null) importTrainingConfiguration(trainingJson);
|
||||||
public KerasSequentialModel() {
|
else
|
||||||
super();
|
log.warn(
|
||||||
|
"If enforceTrainingConfig is true, a training "
|
||||||
|
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
|
||||||
|
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
|
||||||
|
+ " separately no training configuration is attached.");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
this.outputTypes = inferOutputTypes(inputShape);
|
||||||
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
|
||||||
*
|
|
||||||
* @return NeuralNetConfiguration
|
|
||||||
*/
|
|
||||||
public NeuralNetConfiguration getNeuralNetConfiguration()
|
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
|
||||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
|
||||||
if (this.inputLayerNames.size() != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
|
||||||
if (this.outputLayerNames.size() != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
|
if (weightsArchive != null)
|
||||||
|
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
||||||
|
}
|
||||||
|
|
||||||
if (optimizer != null) {
|
/** Default constructor */
|
||||||
modelBuilder.updater(optimizer);
|
public KerasSequentialModel() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure a NeuralNetConfiguration from this Keras Sequential model configuration.
|
||||||
|
*
|
||||||
|
* @return NeuralNetConfiguration
|
||||||
|
*/
|
||||||
|
public NeuralNetConfiguration getNeuralNetConfiguration()
|
||||||
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
|
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
|
||||||
|
if (this.inputLayerNames.size() != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
|
||||||
|
if (this.outputLayerNames.size() != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
||||||
|
|
||||||
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
|
||||||
|
NeuralNetConfiguration.builder();
|
||||||
|
|
||||||
|
if (optimizer != null) {
|
||||||
|
modelBuilder.updater(optimizer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't forcibly override for keras import
|
||||||
|
modelBuilder.overrideNinUponBuild(false);
|
||||||
|
/* Add layers one at a time. */
|
||||||
|
KerasLayer prevLayer = null;
|
||||||
|
int layerIndex = 0;
|
||||||
|
for (KerasLayer layer : this.layersOrdered) {
|
||||||
|
if (layer.isLayer()) {
|
||||||
|
int nbInbound = layer.getInboundLayerNames().size();
|
||||||
|
if (nbInbound != 1)
|
||||||
|
throw new InvalidKerasConfigurationException(
|
||||||
|
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
||||||
|
+ nbInbound
|
||||||
|
+ " for layer "
|
||||||
|
+ layer.getName()
|
||||||
|
+ ")");
|
||||||
|
if (prevLayer != null) {
|
||||||
|
InputType[] inputTypes = new InputType[1];
|
||||||
|
InputPreProcessor preprocessor;
|
||||||
|
if (prevLayer.isInputPreProcessor()) {
|
||||||
|
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
||||||
|
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
||||||
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
|
} else {
|
||||||
|
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
||||||
|
preprocessor = layer.getInputPreprocessor(inputTypes);
|
||||||
|
if (preprocessor != null) {
|
||||||
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
|
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
|
||||||
|
}
|
||||||
|
if (preprocessor != null) {
|
||||||
|
|
||||||
|
Map<Integer, InputPreProcessor> map = new HashMap<>();
|
||||||
|
map.put(layerIndex, preprocessor);
|
||||||
|
modelBuilder.inputPreProcessors(map);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelBuilder.layer(layerIndex++, layer.getLayer());
|
||||||
//don't forcibly override for keras import
|
} else if (layer.getVertex() != null)
|
||||||
modelBuilder.overrideNinUponBuild(false);
|
throw new InvalidKerasConfigurationException(
|
||||||
/* Add layers one at a time. */
|
"Cannot add vertex to NeuralNetConfiguration (class name "
|
||||||
KerasLayer prevLayer = null;
|
+ layer.getClassName()
|
||||||
int layerIndex = 0;
|
+ ", layer name "
|
||||||
for (KerasLayer layer : this.layersOrdered) {
|
+ layer.getName()
|
||||||
if (layer.isLayer()) {
|
+ ")");
|
||||||
int nbInbound = layer.getInboundLayerNames().size();
|
prevLayer = layer;
|
||||||
if (nbInbound != 1)
|
|
||||||
throw new InvalidKerasConfigurationException(
|
|
||||||
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
|
||||||
+ nbInbound + " for layer " + layer.getName() + ")");
|
|
||||||
if (prevLayer != null) {
|
|
||||||
InputType[] inputTypes = new InputType[1];
|
|
||||||
InputPreProcessor preprocessor;
|
|
||||||
if (prevLayer.isInputPreProcessor()) {
|
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
|
||||||
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
|
||||||
} else {
|
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
|
||||||
preprocessor = layer.getInputPreprocessor(inputTypes);
|
|
||||||
if(preprocessor != null) {
|
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
|
|
||||||
|
|
||||||
}
|
|
||||||
if (preprocessor != null)
|
|
||||||
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
modelBuilder.layer(layerIndex++, layer.getLayer());
|
|
||||||
} else if (layer.getVertex() != null)
|
|
||||||
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
|
|
||||||
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
|
|
||||||
prevLayer = layer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
|
||||||
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
|
||||||
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
|
|
||||||
.tbpttFwdLength(truncatedBPTT)
|
|
||||||
.tbpttBackLength(truncatedBPTT);
|
|
||||||
else
|
|
||||||
modelBuilder.backpropType(BackpropType.Standard);
|
|
||||||
|
|
||||||
NeuralNetConfiguration build = modelBuilder.build();
|
|
||||||
|
|
||||||
|
|
||||||
return build;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
||||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
||||||
*
|
modelBuilder
|
||||||
* @return MultiLayerNetwork
|
.backpropType(BackpropType.TruncatedBPTT)
|
||||||
*/
|
.tbpttFwdLength(truncatedBPTT)
|
||||||
public MultiLayerNetwork getMultiLayerNetwork()
|
.tbpttBackLength(truncatedBPTT);
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
else modelBuilder.backpropType(BackpropType.Standard);
|
||||||
return getMultiLayerNetwork(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
NeuralNetConfiguration build = modelBuilder.build();
|
||||||
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
|
||||||
*
|
return build;
|
||||||
* @return MultiLayerNetwork
|
}
|
||||||
*/
|
|
||||||
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
/**
|
||||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
* Build a MultiLayerNetwork from this Keras Sequential model configuration.
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
*
|
||||||
model.init();
|
* @return MultiLayerNetwork
|
||||||
if (importWeights)
|
*/
|
||||||
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
public MultiLayerNetwork getMultiLayerNetwork()
|
||||||
return model;
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
}
|
return getMultiLayerNetwork(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a MultiLayerNetwork from this Keras Sequential model configuration and import weights.
|
||||||
|
*
|
||||||
|
* @return MultiLayerNetwork
|
||||||
|
*/
|
||||||
|
public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
|
||||||
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
|
MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration());
|
||||||
|
model.init();
|
||||||
|
if (importWeights)
|
||||||
|
model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -84,29 +84,29 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
ConvolutionLayer.ConvolutionLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
||||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion))
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
||||||
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
|
.stride(getStrideFromConfig(layerConfig, 1, conf));
|
||||||
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
builder.biasInit(0.0);
|
builder.biasInit(0.0);
|
||||||
if (padding != null)
|
if (padding != null)
|
||||||
builder.padding(padding[0]);
|
builder.padding(padding);
|
||||||
if (biasConstraint != null)
|
if (biasConstraint != null)
|
||||||
builder.constrainBias(biasConstraint);
|
builder.constrainBias(biasConstraint);
|
||||||
if (weightConstraint != null)
|
if (weightConstraint != null)
|
||||||
builder.constrainWeights(weightConstraint);
|
builder.constrainWeights(weightConstraint);
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer;
|
Convolution1D convolution1D = (Convolution1D) layer;
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -114,8 +114,8 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getAtrousConvolution1D() {
|
public Convolution1D getAtrousConvolution1D() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -24,6 +24,7 @@ import lombok.val;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -85,7 +86,7 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
val builder = ConvolutionLayer.builder().name(this.name)
|
val builder = Convolution2D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -93,7 +93,7 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
Convolution1DLayer.Convolution1DLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
@ -125,9 +125,9 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
//set this in order to infer the dimensional format
|
//set this in order to infer the dimensional format
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) this.layer;
|
Convolution1D convolution1D = (Convolution1D) this.layer;
|
||||||
convolution1DLayer.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
convolution1D.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,8 +135,8 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getConvolution1DLayer() {
|
public Convolution1D getConvolution1DLayer() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -95,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
|
||||||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||||
|
|
||||||
final var builder = ConvolutionLayer.builder().name(this.name)
|
final var builder = Convolution2D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||||
|
@ -41,8 +42,8 @@ public class JsonTest extends BaseDL4JTest {
|
||||||
|
|
||||||
};
|
};
|
||||||
for(InputPreProcessor p : pp ){
|
for(InputPreProcessor p : pp ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(p);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(p);
|
||||||
InputPreProcessor p2 = NeuralNetConfiguration.mapper().readValue(s, InputPreProcessor.class);
|
InputPreProcessor p2 = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, InputPreProcessor.class);
|
||||||
assertEquals(p, p2);
|
assertEquals(p, p2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,8 @@ import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||||
|
@ -656,7 +653,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
||||||
true, true, false, null, null);
|
true, true, false, null, null);
|
||||||
Layer l = net.getLayer(0);
|
Layer l = net.getLayer(0);
|
||||||
Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig();
|
Convolution1D c1d = (Convolution1D) l.getTrainingConfig();
|
||||||
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -97,7 +97,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
Convolution1D layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -119,7 +119,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
Convolution1D layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,8 +22,6 @@
|
||||||
package net.brutex.ai.dnn.api;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
|
|
||||||
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
||||||
|
|
||||||
|
|
|
@ -31,9 +31,11 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
public class NN {
|
public class NN {
|
||||||
|
|
||||||
|
|
||||||
public static NeuralNetConfigurationBuilder<?, ?> net() {
|
public static NeuralNetConfigurationBuilder<?, ?> nn() {
|
||||||
return NeuralNetConfiguration.builder();
|
return NeuralNetConfiguration.builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ package net.brutex.ai.dnn.networks;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
||||||
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
||||||
|
|
|
@ -20,6 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
@ -30,11 +34,6 @@ import org.deeplearning4j.earlystopping.termination.IterationTerminationConditio
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.nd4j.common.function.Supplier;
|
import org.nd4j.common.function.Supplier;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,16 +20,15 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
|
|
||||||
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
||||||
|
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
||||||
|
|
||||||
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
||||||
|
|
||||||
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -29,7 +30,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
@ -27,8 +29,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,12 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
||||||
|
|
|
@ -22,9 +22,7 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
/**Terminate training based on max time.
|
/**Terminate training based on max time.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
||||||
|
@ -40,13 +46,6 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
|
|
|
@ -20,6 +20,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.MapperFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.module.SimpleModule;
|
||||||
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.common.primitives.AtomicBoolean;
|
import org.nd4j.common.primitives.AtomicBoolean;
|
||||||
|
@ -28,14 +35,6 @@ import org.nd4j.common.primitives.serde.JsonDeserializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
||||||
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.MapperFeature;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.module.SimpleModule;
|
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,15 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import com.google.common.collect.HashMultiset;
|
|
||||||
import com.google.common.collect.Multiset;
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
||||||
|
|
|
@ -20,14 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.evaluation.EvaluationAveraging;
|
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Getter
|
@Getter
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.evaluation.curves.BaseHistogram;
|
import org.nd4j.evaluation.curves.BaseHistogram;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,13 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.NonNull;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.meta;
|
package org.deeplearning4j.eval.meta;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.adapters;
|
package org.deeplearning4j.nn.adapters;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
@ -32,8 +33,6 @@ import org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
public interface Classifier extends IModel {
|
public interface Classifier extends IModel {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface ITraininableLayerConfiguration {
|
public interface ITraininableLayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -29,10 +29,8 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Param initializer for a layer
|
* Param initializer for a layer
|
||||||
*
|
*
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the model
|
* Update the model
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Classifier;
|
import org.deeplearning4j.nn.api.Classifier;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface IOutputLayer extends Layer, Classifier {
|
public interface IOutputLayer extends Layer, Classifier {
|
||||||
|
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface LayerConstraint extends Cloneable, Serializable {
|
public interface LayerConstraint extends Cloneable, Serializable {
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import java.util.Map;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface RecurrentLayer extends Layer {
|
public interface RecurrentLayer extends Layer {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.*;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
|
@ -34,6 +40,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -42,16 +49,9 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@ -110,7 +110,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return YAML representation of configuration
|
* @return YAML representation of configuration
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
|
@ -127,7 +127,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return {@link ComputationGraphConfiguration}
|
* @return {@link ComputationGraphConfiguration}
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromYaml(String json) {
|
public static ComputationGraphConfiguration fromYaml(String json) {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
try {
|
try {
|
||||||
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
|
@ -140,7 +140,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
//As per NeuralNetConfiguration.toJson()
|
//As per NeuralNetConfiguration.toJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
||||||
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
||||||
|
@ -160,7 +160,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromJson(String json) {
|
public static ComputationGraphConfiguration fromJson(String json) {
|
||||||
//As per NeuralNetConfiguration.fromJson()
|
//As per NeuralNetConfiguration.fromJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
ComputationGraphConfiguration conf;
|
ComputationGraphConfiguration conf;
|
||||||
try {
|
try {
|
||||||
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
|
|
|
@ -19,10 +19,10 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
||||||
|
|
||||||
@JsonSerialize(using = DataFormatSerializer.class)
|
@JsonSerialize(using = DataFormatSerializer.class)
|
||||||
@JsonDeserialize(using = DataFormatDeserializer.class)
|
@JsonDeserialize(using = DataFormatDeserializer.class)
|
||||||
|
|
|
@ -21,14 +21,13 @@
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.Serializable;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface InputPreProcessor extends Serializable, Cloneable {
|
public interface InputPreProcessor extends Serializable, Cloneable {
|
||||||
|
|
|
@ -21,10 +21,9 @@
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import java.util.*;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -35,10 +34,8 @@ import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
|
||||||
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
@ -47,7 +44,6 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||||
import org.deeplearning4j.util.NetworkUtils;
|
import org.deeplearning4j.util.NetworkUtils;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -57,9 +53,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
||||||
* multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers
|
* multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers
|
||||||
|
@ -159,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
@Getter @Setter @NonNull @lombok.Builder.Default
|
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||||
protected BackpropType backpropType = BackpropType.Standard;
|
protected BackpropType backpropType = BackpropType.Standard;
|
||||||
|
|
||||||
@Getter @lombok.Builder.Default
|
@Getter @Setter @Singular
|
||||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||||
/**
|
/**
|
||||||
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
|
* When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
|
||||||
|
@ -331,7 +324,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
*/
|
*/
|
||||||
@Getter @Setter @lombok.Builder.Default private IUpdater biasUpdater = null;
|
@Getter @Setter @lombok.Builder.Default private IUpdater biasUpdater = null;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Weight initialization scheme to use, for initial weight values Note: values set by this method
|
* Weight initialization scheme to use, for initial weight values Note: values set by this method
|
||||||
* will be applied to all applicable layers in the network, unless a different value is explicitly
|
* will be applied to all applicable layers in the network, unless a different value is explicitly
|
||||||
|
@ -339,6 +331,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* and can be overridden on a per-layer basis.
|
* and can be overridden on a per-layer basis.
|
||||||
*/
|
*/
|
||||||
@Getter @Setter @lombok.Builder.Default private IWeightInit weightInit = new WeightInitXavier();
|
@Getter @Setter @lombok.Builder.Default private IWeightInit weightInit = new WeightInitXavier();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the convolution mode for convolutional layers, which impacts padding and output sizes. See
|
* Sets the convolution mode for convolutional layers, which impacts padding and output sizes. See
|
||||||
* {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE<br>
|
* {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE<br>
|
||||||
|
@ -416,113 +409,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
@Getter @Setter @lombok.Builder.Default private double biasInit = 0.0;
|
@Getter @Setter @lombok.Builder.Default private double biasInit = 0.0;
|
||||||
@Getter @Setter @lombok.Builder.Default private double gainInit = 1.0;
|
@Getter @Setter @lombok.Builder.Default private double gainInit = 1.0;
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
|
|
||||||
* from handling of {@link Activation} above.
|
|
||||||
*
|
|
||||||
* @return True if all is well and layer iteration shall continue. False else-wise.
|
|
||||||
*/
|
|
||||||
private static boolean handleLegacyWeightInitFromJson(
|
|
||||||
String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) {
|
|
||||||
if ((l instanceof BaseLayerConfiguration)
|
|
||||||
&& ((BaseLayerConfiguration) l).getWeightInit() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
return false; // Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode weightInit =
|
|
||||||
layerNode.get("weightInit"); // Should only have 1 element: "dense", "output", etc
|
|
||||||
JsonNode distribution = layerNode.get("dist");
|
|
||||||
|
|
||||||
Distribution dist = null;
|
|
||||||
if (distribution != null) {
|
|
||||||
dist = mapper.treeToValue(distribution, Distribution.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weightInit != null) {
|
|
||||||
final IWeightInit wi =
|
|
||||||
WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
|
|
||||||
((BaseLayerConfiguration) l).setWeightInit(wi);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapperYaml() {
|
|
||||||
return JsonMappers.getMapperYaml();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapper() {
|
|
||||||
return JsonMappers.getMapper();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static NeuralNetBaseBuilderConfiguration fromYaml(String input) {
|
|
||||||
throw new RuntimeException("Needs fixing - not supported."); // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return JSON representation of NN configuration
|
|
||||||
*/
|
|
||||||
public String toYaml() {
|
|
||||||
ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapperYaml();
|
|
||||||
synchronized (mapper) {
|
|
||||||
try {
|
|
||||||
return mapper.writeValueAsString(this);
|
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return JSON representation of NN configuration
|
|
||||||
*/
|
|
||||||
public String toJson() {
|
|
||||||
ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapper();
|
|
||||||
synchronized (mapper) {
|
|
||||||
// JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
|
|
||||||
// occasionally
|
|
||||||
// when writeValueAsString is used by multiple threads. This results in invalid JSON. See
|
|
||||||
// issue #3243
|
|
||||||
try {
|
|
||||||
return mapper.writeValueAsString(this);
|
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NeuralNetBaseBuilderConfiguration clone() {
|
public NeuralNetBaseBuilderConfiguration clone() {
|
||||||
NeuralNetBaseBuilderConfiguration clone;
|
NeuralNetBaseBuilderConfiguration clone;
|
||||||
|
@ -561,14 +447,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
|
|
||||||
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
||||||
|
|
||||||
public B activation(Activation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
public B activation(IActivation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
||||||
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
||||||
|
@ -583,7 +462,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B constrainWeights(LayerConstraint... constraints) {
|
public B constrainWeights(LayerConstraint... constraints) {
|
||||||
constrainWeights$value = Arrays.asList(constraints);
|
constrainWeights$value = Arrays.asList(constraints);
|
||||||
constrainWeights$set = true;
|
constrainWeights$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -618,7 +497,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B constrainAllParameters(LayerConstraint... constraints) {
|
public B constrainAllParameters(LayerConstraint... constraints) {
|
||||||
allParamConstraints$value = Arrays.asList(constraints);
|
allParamConstraints$value = Arrays.asList(constraints);
|
||||||
allParamConstraints$set = true;
|
allParamConstraints$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -635,7 +514,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B constrainBias(LayerConstraint... constraints) {
|
public B constrainBias(LayerConstraint... constraints) {
|
||||||
biasConstraints$value = Arrays.asList(constraints);
|
biasConstraints$value = Arrays.asList(constraints);
|
||||||
biasConstraints$set = true;
|
biasConstraints$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -645,11 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param processor what to use to preProcess the data.
|
* @param processor what to use to preProcess the data.
|
||||||
* @return builder pattern
|
* @return builder pattern
|
||||||
*/
|
*/
|
||||||
public B inputPreProcessor(Integer layer, InputPreProcessor processor) {
|
//public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
|
||||||
inputPreProcessors$value.put(layer, processor);
|
// inputPreProcessors$value.put(layer, processor);
|
||||||
inputPreProcessors$set = true;
|
// inputPreProcessors$set = true;
|
||||||
return (B) this;
|
// return self();
|
||||||
}
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set layer at index
|
* Set layer at index
|
||||||
|
@ -658,7 +537,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param layer the layer
|
* @param layer the layer
|
||||||
* @return builder
|
* @return builder
|
||||||
*/
|
*/
|
||||||
public B layer(Integer index, @NonNull LayerConfiguration layer) {
|
public B layer(@NonNull Integer index, @NonNull LayerConfiguration layer) {
|
||||||
innerConfigurations$value.add(index, layer);
|
innerConfigurations$value.add(index, layer);
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return self();
|
return self();
|
||||||
|
@ -680,10 +559,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param layer the layer
|
* @param layer the layer
|
||||||
* @return builder
|
* @return builder
|
||||||
*/
|
*/
|
||||||
|
@JsonIgnore
|
||||||
public B layer(@NonNull LayerConfiguration layer) {
|
public B layer(@NonNull LayerConfiguration layer) {
|
||||||
innerConfigurations$value.add(layer);
|
innerConfigurations$value.add(layer);
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
public B layer(@NonNull LayerConfiguration.LayerConfigurationBuilder<?, ?> layer) {
|
public B layer(@NonNull LayerConfiguration.LayerConfigurationBuilder<?, ?> layer) {
|
||||||
return this.layer(layer.build());
|
return this.layer(layer.build());
|
||||||
|
@ -699,7 +579,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B layersFromArray(@NonNull LayerConfiguration[] arrLayers) {
|
public B layersFromArray(@NonNull LayerConfiguration[] arrLayers) {
|
||||||
innerConfigurations$value.addAll(List.of(arrLayers));
|
innerConfigurations$value.addAll(List.of(arrLayers));
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Specify additional layer configurations */
|
/** Specify additional layer configurations */
|
||||||
|
@ -707,7 +587,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B layersFromList(@NonNull List<LayerConfiguration> listLayers) {
|
public B layersFromList(@NonNull List<LayerConfiguration> listLayers) {
|
||||||
innerConfigurations$value.addAll(listLayers);
|
innerConfigurations$value.addAll(listLayers);
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -723,7 +603,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
regularization$value.add(new L1Regularization(l1));
|
regularization$value.add(new L1Regularization(l1));
|
||||||
}
|
}
|
||||||
regularization$set = true;
|
regularization$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -751,7 +631,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
regularization$value.add(new L2Regularization(l2));
|
regularization$value.add(new L2Regularization(l2));
|
||||||
}
|
}
|
||||||
regularization$set = true;
|
regularization$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -766,7 +646,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
regularizationBias$value.add(new L1Regularization(l1Bias));
|
regularizationBias$value.add(new L1Regularization(l1Bias));
|
||||||
}
|
}
|
||||||
regularizationBias$set = true;
|
regularizationBias$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -791,7 +671,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
"L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
"L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
||||||
regularizationBias$value.add(new L2Regularization(l2Bias));
|
regularizationBias$value.add(new L2Regularization(l2Bias));
|
||||||
}
|
}
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -833,7 +713,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
regularization$value.add(new WeightDecay(coefficient, applyLR));
|
regularization$value.add(new WeightDecay(coefficient, applyLR));
|
||||||
}
|
}
|
||||||
regularization$set = true;
|
regularization$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -870,7 +750,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
regularizationBias$value.add(new WeightDecay(coefficient, applyLR));
|
regularizationBias$value.add(new WeightDecay(coefficient, applyLR));
|
||||||
}
|
}
|
||||||
regularization$set = true;
|
regularization$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -881,7 +761,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public B list() {
|
public B list() {
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -893,23 +773,24 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
*
|
*
|
||||||
* @param distribution Distribution to use for weight initialization
|
* @param distribution Distribution to use for weight initialization
|
||||||
*/
|
*/
|
||||||
@JsonIgnore
|
@JsonIgnore @Deprecated
|
||||||
public B weightInit(Distribution distribution) {
|
public B weightInit(Distribution distribution) {
|
||||||
this.weightInit$value = new WeightInitDistribution(distribution);
|
this.weightInit$value = new WeightInitDistribution(distribution);
|
||||||
this.weightInit$set = true;
|
this.weightInit$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public B weightInit(WeightInit weightInit) {
|
public B weightInit(WeightInit weightInit) {
|
||||||
this.weightInit$value = weightInit.getWeightInitFunction();
|
this.weightInit$value = weightInit.getWeightInitFunction();
|
||||||
this.weightInit$set = true;
|
this.weightInit$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@JsonProperty("weightInit") //this is needed for Jackson < 2.4, otherwise JsonIgnore on the other setters will ignore this also
|
||||||
public B weightInit(IWeightInit iWeightInit) {
|
public B weightInit(IWeightInit iWeightInit) {
|
||||||
this.weightInit$value = iWeightInit;
|
this.weightInit$value = iWeightInit;
|
||||||
this.weightInit$set = true;
|
this.weightInit$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -918,12 +799,13 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param distribution
|
* @param distribution
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@JsonIgnore
|
||||||
public B dist(@NonNull Distribution distribution) {
|
public B dist(@NonNull Distribution distribution) {
|
||||||
return (B) weightInit(distribution);
|
return weightInit(distribution);
|
||||||
}
|
}
|
||||||
|
|
||||||
public B dropOut(@NonNull IDropout dropout) {
|
public B dropOut(@NonNull IDropout dropout) {
|
||||||
return (B) idropOut(dropout);
|
return idropOut(dropout);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -933,7 +815,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @return builder
|
* @return builder
|
||||||
*/
|
*/
|
||||||
public B dropOut(double dropout) {
|
public B dropOut(double dropout) {
|
||||||
return (B) idropOut(new Dropout(dropout));
|
return idropOut(new Dropout(dropout));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -946,7 +828,8 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
public B confs(@NonNull List<NeuralNetConfiguration> confs) {
|
public B confs(@NonNull List<NeuralNetConfiguration> confs) {
|
||||||
innerConfigurations$value.addAll(confs);
|
innerConfigurations$value.addAll(confs);
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return (B) this;
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,39 +22,27 @@ package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.*;
|
||||||
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
import java.util.*;
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.jackson.Jacksonized;
|
import lombok.extern.jackson.Jacksonized;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
|
||||||
import org.deeplearning4j.util.OutputLayerUtil;
|
import org.deeplearning4j.util.OutputLayerUtil;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
||||||
|
@ -62,71 +50,50 @@ import java.util.stream.Collectors;
|
||||||
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
|
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
|
||||||
* learns. They include how many times to update the weights of the model, how to initialize those
|
* learns. They include how many times to update the weights of the model, how to initialize those
|
||||||
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
|
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
|
||||||
* and how fast the model should learn. This is what one configuration would look like:
|
* and how fast the model should learn. This is what one configuration would look like: <br>
|
||||||
* <br/><br/>
|
* <br>
|
||||||
*
|
* NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br>
|
||||||
* NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br/>
|
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br>
|
||||||
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br/>
|
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br>
|
||||||
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br/>
|
* .updater(new Sgd(0.05)) //... other hyperparameters <br>
|
||||||
* .updater(new Sgd(0.05)) //... other hyperparameters <br/>
|
* .backprop(true)<br>
|
||||||
* .backprop(true)<br/>
|
* .build();<br>
|
||||||
* .build();<br/><br/>
|
* <br>
|
||||||
*
|
* With Deeplearning4j, you add a layer by calling layer on the
|
||||||
* With Deeplearning4j, you add a layer
|
* NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
|
||||||
* by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
|
|
||||||
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
|
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
|
||||||
* nIn and nOut, as well as the type: DenseLayer.<br/><br/>
|
* nIn and nOut, as well as the type: DenseLayer.<br>
|
||||||
*
|
* <br>
|
||||||
* .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br/>
|
* .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br>
|
||||||
* .build())<br/><br/>
|
* .build())<br>
|
||||||
*
|
* <br>
|
||||||
* Once you've configured your net, you train the
|
* Once you've configured your net, you train the model with model.fit.
|
||||||
* model with model.fit.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Jacksonized
|
@JsonIgnoreProperties(value = {"net"})
|
||||||
@JsonIgnoreProperties(value={"net"}, ignoreUnknown = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@EqualsAndHashCode(exclude = {"net"}, callSuper = true)
|
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
||||||
//@JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
|
||||||
|
|
||||||
//The inner builder, that we can then extend ...
|
// The inner builder, that we can then extend ...
|
||||||
@SuperBuilder //TODO fix access
|
@Jacksonized
|
||||||
|
@SuperBuilder // TODO fix access
|
||||||
public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
|
|
||||||
|
|
||||||
private IModel net;
|
|
||||||
private static final int DEFAULT_TBPTT_LENGTH = 20;
|
private static final int DEFAULT_TBPTT_LENGTH = 20;
|
||||||
private boolean initCalled = false;
|
|
||||||
|
|
||||||
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NonNull
|
|
||||||
@lombok.Builder.Default
|
|
||||||
@Deprecated
|
|
||||||
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
@Getter
|
|
||||||
@Setter
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
@NonNull
|
|
||||||
@lombok.Builder.Default
|
|
||||||
@Deprecated
|
|
||||||
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
|
|
||||||
|
@Getter @Setter @lombok.Builder.Default protected int iterationCount = 0;
|
||||||
@Getter
|
// Counter for the number of epochs completed so far. Used for per-epoch schedules
|
||||||
@Setter
|
@Getter @Setter @lombok.Builder.Default protected int epochCount = 0;
|
||||||
@lombok.Builder.Default
|
@lombok.Builder.Default protected double dampingFactor = 100;
|
||||||
protected int iterationCount = 0;
|
@EqualsAndHashCode.Exclude private IModel net;
|
||||||
//Counter for the number of epochs completed so far. Used for per-epoch schedules
|
private boolean initCalled = false;
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@lombok.Builder.Default
|
|
||||||
protected int epochCount = 0;
|
|
||||||
@lombok.Builder.Default
|
|
||||||
protected double dampingFactor = 100;
|
|
||||||
// gradient keys used for ensuring order when getting and setting the gradient
|
// gradient keys used for ensuring order when getting and setting the gradient
|
||||||
@lombok.Builder.Default private LinkedHashSet<String> netWideVariables = new LinkedHashSet<>();
|
@lombok.Builder.Default private LinkedHashSet<String> netWideVariables = new LinkedHashSet<>();
|
||||||
|
|
||||||
|
@ -141,22 +108,19 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
*/
|
*/
|
||||||
@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
|
@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN.
|
* Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage
|
||||||
* See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory.
|
* of cuDNN. See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but
|
||||||
* <br>
|
* "NO_WORKSPACE" uses less memory. <br>
|
||||||
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
* Note: values set by this method will be applied to all applicable layers in the network, unless
|
||||||
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
* a different value is explicitly set on a given layer. In other words: values set via this
|
||||||
* value, and can be overridden on a per-layer basis.
|
* method are used as the default value, and can be overridden on a per-layer basis.
|
||||||
|
*
|
||||||
* @param cudnnAlgoMode cuDNN algo mode to use
|
* @param cudnnAlgoMode cuDNN algo mode to use
|
||||||
*/
|
*/
|
||||||
@Getter
|
@Getter @Setter @lombok.Builder.Default
|
||||||
@Setter
|
|
||||||
@lombok.Builder.Default
|
|
||||||
private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
|
private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a neural net configuration from json
|
* Create a neural net configuration from json
|
||||||
*
|
*
|
||||||
|
@ -164,260 +128,23 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
* @return {@link NeuralNetConfiguration}
|
* @return {@link NeuralNetConfiguration}
|
||||||
*/
|
*/
|
||||||
public static NeuralNetConfiguration fromJson(String json) {
|
public static NeuralNetConfiguration fromJson(String json) {
|
||||||
NeuralNetConfiguration conf;
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
|
||||||
try {
|
try {
|
||||||
conf = mapper.readValue(json, NeuralNetConfiguration.class);
|
return mapper.readValue(json, NeuralNetConfiguration.class);
|
||||||
} catch (InvalidTypeIdException e) {
|
} catch (JsonProcessingException e) {
|
||||||
if (e.getMessage().contains("@class")) {
|
|
||||||
try {
|
|
||||||
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
|
|
||||||
return JsonMappers.getLegacyMapper().readValue(json, NeuralNetConfiguration.class);
|
|
||||||
} catch (InvalidTypeIdException e2) {
|
|
||||||
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
|
|
||||||
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
|
|
||||||
String msg = e2.getMessage();
|
|
||||||
if (msg != null && msg.contains("Could not resolve type id")) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Error deserializing NeuralNetConfiguration - configuration may have a custom " +
|
|
||||||
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
|
|
||||||
+
|
|
||||||
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e2);
|
|
||||||
} catch (IOException e2) {
|
|
||||||
throw new RuntimeException(e2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
} catch (IOException e) {
|
|
||||||
//Check if this exception came from legacy deserializer...
|
|
||||||
String msg = e.getMessage();
|
|
||||||
if (msg != null && msg.contains("legacy")) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Error deserializing NeuralNetConfiguration - configuration may have a custom " +
|
|
||||||
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
|
|
||||||
+
|
|
||||||
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
|
|
||||||
// Previously: enumeration used for loss functions. Now: use classes
|
|
||||||
// IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums
|
|
||||||
int layerCount = 0;
|
|
||||||
JsonNode confs = null;
|
|
||||||
for (LayerConfiguration nnc : conf.getFlattenedLayerConfigurations()) {
|
|
||||||
LayerConfiguration l = nnc;
|
|
||||||
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFunction() == null) {
|
|
||||||
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
|
|
||||||
//if so, try walking the JSON graph to extract out the appropriate enum value
|
|
||||||
|
|
||||||
BaseOutputLayer ol = (BaseOutputLayer) l;
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"should never happen"); //return conf; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
JsonNode lossFunctionNode = null;
|
|
||||||
if (outputLayerNode.has("output")) {
|
|
||||||
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
|
|
||||||
} else if (outputLayerNode.has("rnnoutput")) {
|
|
||||||
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lossFunctionNode != null) {
|
|
||||||
String lossFunctionEnumStr = lossFunctionNode.asText();
|
|
||||||
LossFunctions.LossFunction lossFunction = null;
|
|
||||||
try {
|
|
||||||
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lossFunction != null) {
|
|
||||||
switch (lossFunction) {
|
|
||||||
case MSE:
|
|
||||||
ol.setLossFunction(new LossMSE());
|
|
||||||
break;
|
|
||||||
case XENT:
|
|
||||||
ol.setLossFunction(new LossBinaryXENT());
|
|
||||||
break;
|
|
||||||
case NEGATIVELOGLIKELIHOOD:
|
|
||||||
ol.setLossFunction(new LossNegativeLogLikelihood());
|
|
||||||
break;
|
|
||||||
case MCXENT:
|
|
||||||
ol.setLossFunction(new LossMCXENT());
|
|
||||||
break;
|
|
||||||
|
|
||||||
//Remaining: TODO
|
|
||||||
case SQUARED_LOSS:
|
|
||||||
case RECONSTRUCTION_CROSSENTROPY:
|
|
||||||
default:
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
|
|
||||||
lossFunction);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
|
|
||||||
(confs != null ? confs.getClass() : null));
|
|
||||||
}
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
|
|
||||||
//Try to load the old format if necessary, and create the appropriate IActivation instance
|
|
||||||
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getActivationFn() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Should never happen"); //return conf; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode activationFunction = layerNode.get(
|
|
||||||
"activationFunction"); //Should only have 1 element: "dense", "output", etc
|
|
||||||
|
|
||||||
if (activationFunction != null) {
|
|
||||||
Activation ia = Activation.fromString(activationFunction.asText());
|
|
||||||
((BaseLayerConfiguration) l).setActivation(ia.getActivationFunction());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
|
|
||||||
return conf;
|
|
||||||
}
|
|
||||||
|
|
||||||
layerCount++;
|
|
||||||
}
|
|
||||||
return conf;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
|
|
||||||
* from handling of {@link Activation} above.
|
|
||||||
*
|
|
||||||
* @return True if all is well and layer iteration shall continue. False else-wise.
|
|
||||||
*/
|
|
||||||
private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
|
|
||||||
ObjectMapper mapper,
|
|
||||||
JsonNode confs, int layerCount) {
|
|
||||||
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
return false; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode weightInit = layerNode.get(
|
|
||||||
"weightInit"); //Should only have 1 element: "dense", "output", etc
|
|
||||||
JsonNode distribution = layerNode.get("dist");
|
|
||||||
|
|
||||||
Distribution dist = null;
|
|
||||||
if (distribution != null) {
|
|
||||||
dist = mapper.treeToValue(distribution, Distribution.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weightInit != null) {
|
|
||||||
final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
|
|
||||||
.getWeightInitFunction(dist);
|
|
||||||
((BaseLayerConfiguration) l).setWeightInit(wi);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapperYaml() {
|
|
||||||
return JsonMappers.getMapperYaml();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapper() {
|
|
||||||
return JsonMappers.getMapper();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static NeuralNetConfiguration fromYaml(String input) {
|
public static NeuralNetConfiguration fromYaml(String input) {
|
||||||
throw new RuntimeException("Needs fixing - not supported."); //TODO
|
throw new RuntimeException("Needs fixing - not supported."); // TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return JSON representation of NN configuration
|
* @return JSON representation of NN configuration
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
|
@ -431,10 +158,12 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
* @return JSON representation of NN configuration
|
* @return JSON representation of NN configuration
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
// JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
|
||||||
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
// occasionally
|
||||||
|
// when writeValueAsString is used by multiple threads. This results in invalid JSON. See
|
||||||
|
// issue #3243
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
||||||
|
@ -453,7 +182,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public NeuralNetConfiguration clone() {
|
public NeuralNetConfiguration clone() {
|
||||||
NeuralNetConfiguration clone;
|
NeuralNetConfiguration clone;
|
||||||
clone = (NeuralNetConfiguration) super.clone();
|
clone = (NeuralNetConfiguration) super.clone();
|
||||||
if(getStepFunction() != null) { clone.setStepFunction(getStepFunction().clone()); }
|
if (getStepFunction() != null) {
|
||||||
|
clone.setStepFunction(getStepFunction().clone());
|
||||||
|
}
|
||||||
clone.netWideVariables = new LinkedHashSet<>(netWideVariables);
|
clone.netWideVariables = new LinkedHashSet<>(netWideVariables);
|
||||||
clone.setInnerConfigurations(new ArrayList<>(innerConfigurations));
|
clone.setInnerConfigurations(new ArrayList<>(innerConfigurations));
|
||||||
|
|
||||||
|
@ -473,98 +204,109 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
clone.setDataType(this.getDataType());
|
clone.setDataType(this.getDataType());
|
||||||
|
|
||||||
return clone;
|
return clone;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** */
|
||||||
*
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public void init() {
|
public void init() {
|
||||||
if(initCalled) return;
|
if (initCalled) return;
|
||||||
initCalled=true;
|
initCalled = true;
|
||||||
|
|
||||||
/**
|
/** Run init() for each layer */
|
||||||
* Run init() for each layer
|
for (NeuralNetConfiguration nconf : getNetConfigurations()) {
|
||||||
*/
|
|
||||||
for( NeuralNetConfiguration nconf : getNetConfigurations() ) {
|
|
||||||
nconf.init();
|
nconf.init();
|
||||||
}
|
}
|
||||||
//getNetConfigurations().stream().forEach( conf -> {
|
// getNetConfigurations().stream().forEach( conf -> {
|
||||||
// conf.init(); //do not call on self
|
// conf.init(); //do not call on self
|
||||||
//}); //call init on all embedded net configurations
|
// }); //call init on all embedded net configurations
|
||||||
|
|
||||||
//TODO do not put inside self to avoid serialization issues
|
// TODO do not put inside self to avoid serialization issues
|
||||||
// innerConfigurations.add(0, this); //put this configuration at first place
|
// innerConfigurations.add(0, this); //put this configuration at first place
|
||||||
|
|
||||||
|
|
||||||
|
getLayerConfigurations().stream()
|
||||||
|
.forEach(
|
||||||
|
lconf ->
|
||||||
|
lconf.setNetConfiguration(
|
||||||
|
this)); // set this as net config for all layers (defined in here, not stacked
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inherit network wide configuration setting to those layer configurations
|
* Inherit network wide configuration setting to those layer configurations that do not have an
|
||||||
* that do not have an individual setting (nor a default)
|
* individual setting (nor a default)
|
||||||
*/
|
*/
|
||||||
for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
|
for (LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
|
||||||
lconf.runInheritance();
|
lconf.runInheritance();
|
||||||
}
|
}
|
||||||
|
|
||||||
getLayerConfigurations().stream().forEach( lconf -> lconf.setNetConfiguration(this)); //set this as net config for all layers (defined in here, not stacked
|
|
||||||
|
|
||||||
|
// Validate BackpropType setting
|
||||||
//Validate BackpropType setting
|
|
||||||
if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
|
if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
|
||||||
&& backpropType != BackpropType.TruncatedBPTT) {
|
&& backpropType != BackpropType.TruncatedBPTT) {
|
||||||
log.warn("Truncated backpropagation through time lengths have been configured with values "
|
log.warn(
|
||||||
+ tbpttFwdLength
|
"Truncated backpropagation through time lengths have been configured with values "
|
||||||
+ " and " + tbpttBackLength + " but backprop type is set to " + backpropType
|
+ tbpttFwdLength
|
||||||
+ ". TBPTT configuration" +
|
+ " and "
|
||||||
" settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
|
+ tbpttBackLength
|
||||||
|
+ " but backprop type is set to "
|
||||||
|
+ backpropType
|
||||||
|
+ ". TBPTT configuration"
|
||||||
|
+ " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (backpropType == BackpropType.TruncatedBPTT && isValidateTbpttConfig()) {
|
if (backpropType == BackpropType.TruncatedBPTT && isValidateTbpttConfig()) {
|
||||||
//Check for invalid combination - tbptt plus LastTimeStepLayer or
|
// Check for invalid combination - tbptt plus LastTimeStepLayer or
|
||||||
for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
|
for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
|
||||||
LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
|
LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
|
||||||
if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
|
if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
|
"Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
|
||||||
+
|
+ " cannot be used with layer "
|
||||||
" cannot be used with layer " + i + " of type " + l.getClass().getName()
|
+ i
|
||||||
+ ": TBPTT is incompatible with this layer type (which is designed " +
|
+ " of type "
|
||||||
"to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
|
+ l.getClass().getName()
|
||||||
+
|
+ ": TBPTT is incompatible with this layer type (which is designed "
|
||||||
"This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
|
+ "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
|
||||||
|
+ "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (getInputType() == null && inputPreProcessors.get(0) == null) {
|
if (getInputType() == null && inputPreProcessors.get(0) == null) {
|
||||||
//User hasn't set the InputType. Sometimes we can infer it...
|
// User hasn't set the InputType. Sometimes we can infer it...
|
||||||
// For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in
|
// For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to
|
||||||
|
// feed in
|
||||||
// standard feedforward or RNN data
|
// standard feedforward or RNN data
|
||||||
//This isn't the most elegant implementation, but should avoid breaking backward compatibility here
|
// This isn't the most elegant implementation, but should avoid breaking backward
|
||||||
//Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
|
// compatibility here
|
||||||
|
// Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
|
||||||
LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0);
|
LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0);
|
||||||
if (firstLayer instanceof BaseRecurrentLayer) {
|
if (firstLayer instanceof BaseRecurrentLayer) {
|
||||||
BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
|
BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
|
||||||
val nIn = brl.getNIn();
|
val nIn = brl.getNIn();
|
||||||
if (nIn > 0) {
|
if (nIn > 0) {
|
||||||
setInputType( InputType.recurrent(nIn, brl.getDataFormat()));
|
setInputType(InputType.recurrent(nIn, brl.getDataFormat()));
|
||||||
}
|
}
|
||||||
} else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
|
} else if (firstLayer instanceof DenseLayer
|
||||||
|
|| firstLayer instanceof EmbeddingLayer
|
||||||
|| firstLayer instanceof OutputLayer) {
|
|| firstLayer instanceof OutputLayer) {
|
||||||
//Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer
|
// Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a
|
||||||
|
// FeedForwardLayer
|
||||||
FeedForwardLayer ffl = (FeedForwardLayer) firstLayer;
|
FeedForwardLayer ffl = (FeedForwardLayer) firstLayer;
|
||||||
val nIn = ffl.getNIn();
|
val nIn = ffl.getNIn();
|
||||||
if (nIn > 0) {
|
if (nIn > 0) {
|
||||||
setInputType( InputType.feedForward(nIn));
|
setInputType(InputType.feedForward(nIn));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Add preprocessors and set nIns, if InputType has been set
|
// Add preprocessors and set nIns, if InputType has been set
|
||||||
// Builder.inputType field can be set in 1 of 4 ways:
|
// Builder.inputType field can be set in 1 of 4 ways:
|
||||||
// 1. User calls setInputType directly
|
// 1. User calls setInputType directly
|
||||||
// 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...))
|
// 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...))
|
||||||
// 3. Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field
|
// 3. Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets
|
||||||
if(inputPreProcessors == null) {
|
// the inputType field
|
||||||
|
if (inputPreProcessors == null) {
|
||||||
inputPreProcessors = new HashMap<>();
|
inputPreProcessors = new HashMap<>();
|
||||||
}
|
}
|
||||||
if (getInputType() != null) {
|
if (getInputType() != null) {
|
||||||
|
@ -572,10 +314,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
|
for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
|
||||||
LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
|
LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
|
||||||
if (inputPreProcessors.get(i) == null) {
|
if (inputPreProcessors.get(i) == null) {
|
||||||
//Don't override preprocessor setting, but set preprocessor if required...
|
// Don't override preprocessor setting, but set preprocessor if required...
|
||||||
@NonNull
|
@NonNull
|
||||||
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
|
InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
|
||||||
if (inputPreProcessor != null) {
|
if (inputPreProcessor != null) {
|
||||||
|
inputPreProcessors = new HashMap<>(inputPreProcessors);
|
||||||
inputPreProcessors.put(i, inputPreProcessor);
|
inputPreProcessors.put(i, inputPreProcessor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -586,41 +329,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
}
|
}
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1);
|
LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1);
|
||||||
//convolution 1d is an edge case where it has rnn input type but the filters
|
// convolution 1d is an edge case where it has rnn input type but the filters
|
||||||
//should be the output
|
// should be the output
|
||||||
if (layer instanceof Convolution1DLayer) {
|
if (layer instanceof Convolution1D || layer instanceof Convolution1DNew) {
|
||||||
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
||||||
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) getInputType();
|
InputType.InputTypeRecurrent recurrent =
|
||||||
|
(InputType.InputTypeRecurrent) getInputType();
|
||||||
feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
|
feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
|
||||||
|
// by the user
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
|
||||||
|
// by the user
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set by
|
||||||
|
// the user
|
||||||
}
|
}
|
||||||
|
|
||||||
currentInputType = l.getOutputType(i, currentInputType);
|
currentInputType = l.getOutputType(i, currentInputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(getSeed());
|
Nd4j.getRandom().setSeed(getSeed());
|
||||||
|
|
||||||
//Validate output layer configuration
|
// Validate output layer configuration
|
||||||
if (isValidateOutputLayerConfig()) {
|
if (isValidateOutputLayerConfig()) {
|
||||||
//Validate output layer configurations...
|
// Validate output layer configurations...
|
||||||
for (LayerConfiguration n : getFlattenedLayerConfigurations()) {
|
for (LayerConfiguration n : getFlattenedLayerConfigurations()) {
|
||||||
OutputLayerUtil.validateOutputLayer(n.getName(), n); //No-op for non output/loss layers
|
OutputLayerUtil.validateOutputLayer(n.getName(), n); // No-op for non output/loss layers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -646,26 +395,28 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
layerName = String.valueOf(i);
|
layerName = String.valueOf(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Pass input type through preprocessor, if necessary
|
// Pass input type through preprocessor, if necessary
|
||||||
InputPreProcessor preproc = getInputPreProcess(i);
|
InputPreProcessor preproc = getInputPreProcess(i);
|
||||||
//TODO memory requirements for preprocessor
|
// TODO memory requirements for preprocessor
|
||||||
if (preproc != null) {
|
if (preproc != null) {
|
||||||
inputType = preproc.getOutputType(inputType);
|
inputType = preproc.getOutputType(inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
LayerMemoryReport report = getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
|
LayerMemoryReport report =
|
||||||
|
getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
|
||||||
memoryReportMap.put(layerName, report);
|
memoryReportMap.put(layerName, report);
|
||||||
|
|
||||||
inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType);
|
inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new NetworkMemoryReport(memoryReportMap, NeuralNetConfiguration.class,
|
return new NetworkMemoryReport(
|
||||||
"MultiLayerNetwork", inputType);
|
memoryReportMap, NeuralNetConfiguration.class, "MultiLayerNetwork", inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For the given input shape/type for the network, return a list of activation sizes for each
|
* For the given input shape/type for the network, return a list of activation sizes for each
|
||||||
* layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
|
* layer in the network.<br>
|
||||||
|
* i.e., list.get(i) is the output activation sizes for layer i
|
||||||
*
|
*
|
||||||
* @param inputType Input type for the network
|
* @param inputType Input type for the network
|
||||||
* @return A lits of activation types for the network, indexed by layer number
|
* @return A lits of activation types for the network, indexed by layer number
|
||||||
|
@ -699,38 +450,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public void addNetWideVariable(String variable) {
|
public void addNetWideVariable(String variable) {
|
||||||
if (!netWideVariables.contains(variable)) {
|
if (!netWideVariables.contains(variable)) {
|
||||||
netWideVariables.add(variable);
|
netWideVariables.add(variable);
|
||||||
log.trace("Adding neural network wide variable '{}' to the list of variables. New length is {}.", variable, netWideVariables.size());
|
log.trace(
|
||||||
|
"Adding neural network wide variable '{}' to the list of variables. New length is {}.",
|
||||||
|
variable,
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
log.trace("Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.", variable, netWideVariables.size());
|
log.trace(
|
||||||
|
"Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.",
|
||||||
|
variable,
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void clearNetWideVariable() {
|
public void clearNetWideVariable() {
|
||||||
|
|
||||||
netWideVariables.clear();
|
netWideVariables.clear();
|
||||||
log.trace("Adding neural network wide variables have been cleared. New length is {}.", netWideVariables.size());
|
log.trace(
|
||||||
|
"Adding neural network wide variables have been cleared. New length is {}.",
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* From the list of layers and neural net configurations, only return the Layer Configurations that
|
* From the list of layers and neural net configurations, only return the Layer Configurations
|
||||||
* are defined in this neural network (it does not include embedded neural network configuration
|
* that are defined in this neural network (it does not include embedded neural network
|
||||||
* layers)
|
* configuration layers)
|
||||||
|
*
|
||||||
* @return list with layer configurations
|
* @return list with layer configurations
|
||||||
*/
|
*/
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public List<LayerConfiguration> getLayerConfigurations() {
|
public List<LayerConfiguration> getLayerConfigurations() {
|
||||||
return innerConfigurations.stream()
|
return innerConfigurations.stream()
|
||||||
.filter(obj -> (obj instanceof LayerConfiguration))
|
.filter(obj -> (obj instanceof LayerConfiguration))
|
||||||
.map( obj -> (LayerConfiguration)obj )
|
.map(obj -> (LayerConfiguration) obj)
|
||||||
.collect( Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* From the list of layers and neural net configurations, only return the neural net configurations
|
* From the list of layers and neural net configurations, only return the neural net
|
||||||
|
* configurations
|
||||||
|
*
|
||||||
* @return list with neural net configurations
|
* @return list with neural net configurations
|
||||||
*/
|
*/
|
||||||
//@Synchronized("innerConfigurationsLock")
|
// @Synchronized("innerConfigurationsLock")
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public List<NeuralNetConfiguration> getNetConfigurations() {
|
public List<NeuralNetConfiguration> getNetConfigurations() {
|
||||||
List<NeuralNetConfiguration> list;
|
List<NeuralNetConfiguration> list;
|
||||||
|
@ -751,35 +511,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
* @return list of layer configurations
|
* @return list of layer configurations
|
||||||
*/
|
*/
|
||||||
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
|
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
|
||||||
List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list
|
List<LayerConfiguration> ret = new ArrayList<>(); // create the final return list
|
||||||
//When properly initialized, _this_ configuration is set first in the list, however we
|
// When properly initialized, _this_ configuration is set first in the list, however we
|
||||||
//can find cases where this is not true, thus the first configuration is another net or layer configuration
|
// can find cases where this is not true, thus the first configuration is another net or layer
|
||||||
//and should not be skipped. In essence, skip first configuration if that is "this".
|
// configuration
|
||||||
//TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
|
// and should not be skipped. In essence, skip first configuration if that is "this".
|
||||||
|
// TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
|
||||||
int iSkip = 0;
|
int iSkip = 0;
|
||||||
if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
|
if (conf.getInnerConfigurations().size() > 0
|
||||||
conf.getInnerConfigurations().stream().skip(iSkip)
|
&& conf.getInnerConfigurations().get(0).equals(this)) {
|
||||||
.forEach(obj -> {
|
iSkip = 1;
|
||||||
//if Layer Config, include in list and inherit parameters from this conf
|
}
|
||||||
//else if neural net configuration, call self recursively to resolve layer configurations
|
conf.getInnerConfigurations().stream()
|
||||||
|
.skip(iSkip)
|
||||||
|
.forEach(
|
||||||
|
obj -> {
|
||||||
|
// if Layer Config, include in list and inherit parameters from this conf
|
||||||
|
// else if neural net configuration, call self recursively to resolve layer
|
||||||
|
// configurations
|
||||||
if (obj instanceof LayerConfiguration) {
|
if (obj instanceof LayerConfiguration) {
|
||||||
((LayerConfiguration) obj).setNetConfiguration(conf);
|
((LayerConfiguration) obj).setNetConfiguration(conf);
|
||||||
ret.add((LayerConfiguration) obj);
|
ret.add((LayerConfiguration) obj);
|
||||||
} else if (obj instanceof NeuralNetConfiguration)
|
} else if (obj instanceof NeuralNetConfiguration)
|
||||||
ret.addAll(getFlattenedLayerConfigurations(
|
ret.addAll(getFlattenedLayerConfigurations((NeuralNetConfiguration) obj));
|
||||||
(NeuralNetConfiguration) obj));
|
|
||||||
else {
|
else {
|
||||||
log.error(
|
log.error(
|
||||||
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
|
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
|
||||||
obj.getClass().getSimpleName());
|
obj.getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
// make sure the indexes are sequenced properly
|
||||||
|
AtomicInteger i = new AtomicInteger();
|
||||||
|
ret.forEach(obj -> {
|
||||||
|
obj.setIndex(i.getAndIncrement());
|
||||||
|
});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this configurations
|
* Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this
|
||||||
* list of configurations
|
* configurations list of configurations
|
||||||
|
*
|
||||||
* @return list of layer configurations
|
* @return list of layer configurations
|
||||||
*/
|
*/
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
|
@ -789,6 +561,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a new layer to the first position
|
* Add a new layer to the first position
|
||||||
|
*
|
||||||
* @param layer configuration
|
* @param layer configuration
|
||||||
*/
|
*/
|
||||||
public void setLayer(@NonNull LayerConfiguration layer) {
|
public void setLayer(@NonNull LayerConfiguration layer) {
|
||||||
|
@ -801,26 +574,28 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deprecated, do not use. Workaround for old tests
|
* Deprecated, do not use. Workaround for old tests and getFlattenedLayerConfigurations().get(0);
|
||||||
* and getFlattenedLayerConfigurations().get(0);
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Deprecated @JsonIgnore
|
@Deprecated
|
||||||
|
@JsonIgnore
|
||||||
public LayerConfiguration getFirstLayer() {
|
public LayerConfiguration getFirstLayer() {
|
||||||
log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
|
log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
|
||||||
return getFlattenedLayerConfigurations().get(0);
|
return getFlattenedLayerConfigurations().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
protected boolean canEqual(final Object other) {
|
||||||
|
return other instanceof NeuralNetConfiguration;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
protected boolean canEqual(final Object other) {
|
public abstract static class NeuralNetConfigurationBuilder<
|
||||||
return other instanceof NeuralNetConfiguration;
|
C extends NeuralNetConfiguration,
|
||||||
}
|
B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>>
|
||||||
|
extends NeuralNetBaseBuilderConfigurationBuilder<C, B> {
|
||||||
|
|
||||||
public static abstract class NeuralNetConfigurationBuilder<C extends NeuralNetConfiguration,
|
|
||||||
B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>> extends
|
|
||||||
NeuralNetBaseBuilderConfigurationBuilder<C, B> {
|
|
||||||
|
|
||||||
public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
|
public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
|
||||||
return new ComputationGraphConfiguration.GraphBuilder(this);
|
return new ComputationGraphConfiguration.GraphBuilder(this);
|
||||||
|
@ -829,10 +604,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public NeuralNetConfigurationBuilder clone() {
|
public NeuralNetConfigurationBuilder clone() {
|
||||||
try {
|
try {
|
||||||
return (NeuralNetConfigurationBuilder) super.clone();
|
return (NeuralNetConfigurationBuilder) super.clone();
|
||||||
} catch(CloneNotSupportedException ex) {
|
} catch (CloneNotSupportedException ex) {
|
||||||
throw new RuntimeException(ex);
|
throw new RuntimeException(ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* N is the batch size<br/>
|
||||||
|
* C is the number of feature maps (that is,, number of channels)<br/>
|
||||||
|
* H is the image height (not used for 1D conv as this is an RNN format<br/>
|
||||||
|
* W is the image width<br/>
|
||||||
|
* **/
|
||||||
public enum RNNFormat implements DataFormat {
|
public enum RNNFormat implements DataFormat {
|
||||||
NCW,
|
/** n=batch size; c=channels/ features; w=width **/ NCW,
|
||||||
NWC
|
/** n=batch size; w=width; c=channels/ features **/ NWC
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
@ -27,11 +30,6 @@ import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -27,9 +29,6 @@ import org.nd4j.linalg.factory.Broadcast;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
public class MaxNormConstraint extends BaseConstraint {
|
public class MaxNormConstraint extends BaseConstraint {
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue