Compare commits


26 Commits

Author SHA1 Message Date
Brian Rosenberger 1c1ec071ef Allow docker container to use GPU
Signed-off-by: brian <brian@brutex.de>
2023-08-15 23:14:12 +02:00
Brian Rosenberger 74ad5087c1 Allow docker container to use GPU
Signed-off-by: brian <brian@brutex.de>
2023-08-15 20:58:43 +02:00
Brian Rosenberger acae3944ec Allow docker container to use GPU
Signed-off-by: brian <brian@brutex.de>
2023-08-15 20:32:46 +02:00
Brian Rosenberger be7cd6b930 Allow docker container to use GPU
Signed-off-by: brian <brian@brutex.de>
2023-08-15 20:25:35 +02:00
Brian Rosenberger 99aed71ffa Merge remote-tracking branch 'brutex-origin/master'
# Conflicts:
#	cavis-common-platform/build.gradle
2023-08-15 16:59:12 +02:00
Brian Rosenberger 2df8ea06e0 Allow docker container to use GPU
Signed-off-by: brian <brian@brutex.de>
2023-08-15 16:58:55 +02:00
Brian Rosenberger 090c5ab2eb Updating extended GAN tests
Signed-off-by: brian <brian@brutex.de>
2023-08-14 14:07:03 +02:00
Brian Rosenberger a40d5aa7cf Update docker image to 12.1.0-cudnn8-devel-ubuntu22.04
Signed-off-by: brian <brian@brutex.de>
2023-08-12 20:25:46 +02:00
Brian Rosenberger d2972e4f24 Adding collection of junit test results for jenkins CPU pipelines
Signed-off-by: brian <brian@brutex.de>
2023-08-11 09:04:22 +02:00
Brian Rosenberger 704f4860d5 Update vulnerable hadoop-common from 3.2.0 to 3.2.4
Signed-off-by: brian <brian@brutex.de>
2023-08-11 08:03:29 +02:00
Brian Rosenberger d5728cbd8e Fix internal archiva credential parameter
Signed-off-by: brian <brian@brutex.de>
2023-08-10 21:21:09 +02:00
Brian Rosenberger d40c044df4 Jenkins pipeline for CUDA builds to include junit post build action
Signed-off-by: brian <brian@brutex.de>
2023-08-10 14:19:09 +02:00
Brian Rosenberger a6c4a16d9a Update underlying library versions for javacpp 1.5.9
Signed-off-by: brian <brian@brutex.de>
2023-08-10 10:55:46 +02:00
Brian Rosenberger 0e4be5c4d2 Tag some junit tests with @Tag("long-running") to skip them during normal build
Signed-off-by: brian <brian@brutex.de>
2023-08-09 12:13:31 +02:00
Brian Rosenberger f7be1e324f Update leptonica library to 1.83.0
Signed-off-by: brian <brian@brutex.de>
2023-08-09 11:45:48 +02:00
Brian Rosenberger 1c3496ad84 gan example
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger 3ea555b645 add test stage to linux cuda on docker build
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger e11568605d Update lombok
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger 9f0682eb75 Downgrade gradle wrapper to 7.4.2 and upgrade javacpp-gradle plugin to 1.5.9
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger ca127d8b88 Fixed missing imports
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger deb436036b Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger 1f2bfb36a5 Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger b477b71325 Change jenkins pipeline credentials id for MAVEN
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger d75e0be506 Fix build docker image to use CUDA 11.4.3 (was 11.4.0)
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger 318cafb6f0 Fix build docker image to use CUDA 11.4.3 (was 11.4.0)
Signed-off-by: brian <brian@brutex.de>
2023-08-07 10:39:16 +02:00
Brian Rosenberger 24466a8fd4 Fixing Tests
2023-08-07 10:39:16 +02:00
465 changed files with 7763 additions and 4338 deletions


@@ -1,4 +1,4 @@
-FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
+FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
@@ -11,5 +11,10 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
    rm cmake-3.24.2-linux-x86_64.sh
+RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
+RUN ldconfig -p | grep cuda

.gitignore (vendored): 13 additions

@@ -36,6 +36,8 @@ pom.xml.versionsBackup
pom.xml.next
release.properties
*dependency-reduced-pom.xml
+**/build/*
+.gradle/*

# Specific for Nd4j
*.md5
@@ -83,3 +85,14 @@ bruai4j-native-common/cmake*
/bruai4j-native/bruai4j-native-common/blasbuild/
/bruai4j-native/bruai4j-native-common/build/
/cavis-native/cavis-native-lib/blasbuild/
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
+/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
+/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
+/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
+/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
+/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF


@@ -35,7 +35,7 @@ pipeline {
        }
        stage('build-linux-cpu') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }
@@ -65,7 +65,7 @@ pipeline {
        }*/
        stage('publish-linux-cpu') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }
@@ -79,4 +79,9 @@ pipeline {
            }
        }
    }
+   post {
+       always {
+           junit '**/build/test-results/**/*.xml'
+       }
+   }
}


@@ -21,13 +21,15 @@
pipeline {
    agent {
-       dockerfile {
+       /* dockerfile {
            filename 'Dockerfile'
            dir '.docker'
            label 'linux && cuda'
            //additionalBuildArgs '--build-arg version=1.0.2'
            //args '--gpus all' --needed for test only, you can build without GPU
        }
+       */
+       label 'linux && cuda'
    }
    stages {
@@ -43,13 +45,13 @@ pipeline {
        }
        stage('build-linux-cuda') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }
            steps {
                withGradle {
-                   sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
+                   sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
                    -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
                    -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
                }
@@ -57,4 +59,10 @@ pipeline {
            }
        }
    }
+   post {
+       always {
+           junit '**/build/test-results/**/*.xml'
+       }
+   }
}


@@ -47,7 +47,7 @@ pipeline {
        }
        stage('build-linux-cuda') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }


@@ -41,7 +41,7 @@ pipeline {
        }
        stage('build-linux-cpu') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }
@@ -85,4 +85,9 @@ pipeline {
            }
        }
    }
+   post {
+       always {
+           junit '**/build/test-results/**/*.xml'
+       }
+   }
}


@@ -33,7 +33,7 @@ pipeline {
    stages {
        stage('publish-linux-cpu') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }


@@ -26,7 +26,7 @@ pipeline {
            dir '.docker'
            label 'linux && docker && cuda'
            //additionalBuildArgs '--build-arg version=1.0.2'
-           //args '--gpus all' --needed for test only, you can build without GPU
+           args '--gpus all' //needed for test only, you can build without GPU
        }
    }
@@ -43,7 +43,7 @@ pipeline {
        }
        stage('build-linux-cuda') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }
@@ -56,5 +56,26 @@ pipeline {
                //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
            }
        }
+       stage('test-linux-cuda') {
+           environment {
+               MAVEN = credentials('Internal_Archiva')
+               OSSRH = credentials('OSSRH')
+           }
+           steps {
+               withGradle {
+                   sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
+                   -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
+                   -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
+               }
+               //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
+           }
+       }
    }
+   post {
+       always {
+           junit '**/build/test-results/**/*.xml'
+       }
+   }
}


@@ -41,7 +41,7 @@ pipeline {
        }
        stage('build-linux-cpu') {
            environment {
-               MAVEN = credentials('Internal Archiva')
+               MAVEN = credentials('Internal_Archiva')
                OSSRH = credentials('OSSRH')
            }


@@ -0,0 +1,167 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.nd4j.tests;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@Slf4j
public class ExploreParamsTest {
@Test
public void testParam() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.seed(12345)
.dataType(DataType.DOUBLE)
.layer(
DenseLayer.builder().nIn(4).nOut(30).name("1. Dense").activation(Activation.TANH))
.layer(DenseLayer.builder().nIn(30).nOut(10).name("2. Dense"))
// .layer(FrozenLayer.builder(DenseLayer.builder().nOut(6).build()).build())
.layer(
OutputLayer.builder()
.nOut(3)
.lossFunction(LossFunctions.LossFunction.MSE)
.activation(Activation.SOFTMAX))
.build();
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
nn.init();
log.info(nn.summary());
// INDArray input = Nd4j.rand(10,4);
INDArray labels = Nd4j.zeros(9, 3);
INDArray input =
Nd4j.create(
new double[][] {
{5.15, 3.5, 1.4, 0.21}, // setosa
{4.9, 3.2, 1.4, 0.2}, // setosa
{4.7, 3.2, 1.23, 0.2}, // setosa
{7, 3.25, 4.7, 1.41}, // versicolor
{6.4, 3.2, 4.54, 1.5}, // versicolor
{6.9, 3.1, 4.92, 1.5}, // versicolor
{7.7, 3, 6.1, 2.3}, // virginica
{6.3, 3.4, 5.6, 2.45}, // virginica
{6.4, 3.12, 5.5, 1.8} // virginica
});
labels.putScalar(0, 1);
labels.putScalar(3, 1);
labels.putScalar(6, 1);
labels.putScalar(10, 1);
labels.putScalar(13, 1);
labels.putScalar(16, 1);
labels.putScalar(20, 1);
labels.putScalar(23, 1);
labels.putScalar(26, 1);
IrisDataSetIterator iter = new IrisDataSetIterator();
//Iterable<Pair<INDArray, INDArray>> it = List.of(new Pair<INDArray, INDArray>(input, labels));
List l = new ArrayList<>();
for (int i=0; i< input.rows(); i++) {
l.add(new Pair(input.getRow(i), labels.getRow(i)));
}
Iterable<Pair<INDArray, INDArray>> it = l;
INDArrayDataSetIterator diter = new INDArrayDataSetIterator(it, 1);
for (int i = 0; i < 100; i++) {
// nn.fit(input, labels);
// nn.fit( input, labels);
nn.fit(diter);
// nn.feedForward(input);
if(i%20==0) log.info("Score: {}", nn.getScore());
}
Evaluation eval = nn.evaluate(iter, List.of("setosa", "versicolor", "virginica"));
log.info("\n{}", eval.stats());
}
@Test
public void testParam2() throws IOException {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.seed(12345)
.layer(
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense"))
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
.layer(
OutputLayer.builder()
.nOut(10)
.lossFunction(LossFunctions.LossFunction.MSE)
.activation(Activation.SOFTMAX))
.build();
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
nn.init();
log.info(nn.summary());
NeuralNetConfiguration conf2 =
NeuralNetConfiguration.builder()
.seed(12345)
.layer(
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense").dropOut(0.7))
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
.layer(
OutputLayer.builder()
.nOut(10)
.lossFunction(LossFunctions.LossFunction.MSE)
.activation(Activation.SOFTMAX))
.build();
MultiLayerNetwork nn2 = new MultiLayerNetwork(conf2);
nn2.init();
log.info(nn2.summary());
MnistDataSetIterator iter = new MnistDataSetIterator(10, 500);
MnistDataSetIterator iter2 = new MnistDataSetIterator(10, 50);
for (int i = 0; i < 200; i++) {
nn.fit(iter);
nn2.fit(iter);
if(i%20==0) log.info("Score: {} vs. {}", nn.getScore(), nn2.getScore());
}
Evaluation eval = nn.evaluate(iter2);
Evaluation eval2 = nn2.evaluate(iter2);
log.info("\n{} \n{}", eval.stats(), eval2.stats());
}
}
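
Note on the label encoding above: the putScalar calls one-hot encode the nine Iris samples through flat row-major indices, where element r * 3 + c marks sample r as class c (indices 0, 3, 6 select class 0 for rows 0 to 2; 10, 13, 16 class 1; 20, 23, 26 class 2). A minimal sketch of the same encoding as a loop, assuming only Nd4j (class and method names here are illustrative, not part of the change set):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OneHotSketch {
  // Same flat-index trick as ExploreParamsTest: labels is rows x numClasses,
  // and putScalar(r * numClasses + c, 1.0) sets row r, column c.
  static INDArray oneHot(int[] classes, int numClasses) {
    INDArray labels = Nd4j.zeros(classes.length, numClasses);
    for (int r = 0; r < classes.length; r++) {
      labels.putScalar((long) r * numClasses + classes[r], 1.0);
    }
    return labels;
  }

  public static void main(String[] args) {
    // three setosa (0), three versicolor (1), three virginica (2) reproduces the test's matrix
    System.out.println(oneHot(new int[] {0, 0, 0, 1, 1, 1, 2, 2, 2}, 3));
  }
}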


@@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class LoadBackendTests {
    @Test
-   public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
+   public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
        // check if Nd4j is there
-       //Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
+       Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
        final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
        sysPathsField.setAccessible(true);
        sysPathsField.set(null, null);


@@ -1,110 +1,49 @@
-/*
- *
- * ******************************************************************************
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- *
- */
package net.brutex.gan;
-import java.awt.BorderLayout;
-import java.awt.Dimension;
-import java.awt.GridLayout;
-import java.awt.Image;
+import static net.brutex.ai.dnn.api.NN.dense;
+import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
-import java.util.Random;
-import javax.swing.ImageIcon;
-import javax.swing.JFrame;
-import javax.swing.JLabel;
-import javax.swing.JPanel;
-import javax.swing.WindowConstants;
-import lombok.extern.slf4j.Slf4j;
+import javax.swing.*;
import org.apache.commons.lang3.ArrayUtils;
-import org.datavec.api.split.FileSplit;
-import org.datavec.image.loader.NativeImageLoader;
-import org.datavec.image.recordreader.ImageRecordReader;
-import org.datavec.image.transform.ColorConversionTransform;
-import org.datavec.image.transform.ImageTransform;
-import org.datavec.image.transform.PipelineImageTransform;
-import org.datavec.image.transform.ResizeImageTransform;
-import org.datavec.image.transform.ShowImageTransform;
-import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.distribution.Distribution;
-import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.ActivationLayer;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.DropoutLayer;
-import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
-import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
+import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
-import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
-@Slf4j
+import org.nd4j.linalg.lossfunctions.LossFunctions;
public class App {
-  private static final double LEARNING_RATE = 0.000002;
+  private static final double LEARNING_RATE = 0.002;
  private static final double GRADIENT_THRESHOLD = 100.0;
-  private static final int X_DIM = 20 ;
-  private static final int Y_DIM = 20;
-  private static final int CHANNELS = 1;
-  private static final int batchSize = 10;
-  private static final int INPUT = 128;
-  private static final int OUTPUT_PER_PANEL = 4;
-  private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
  private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
+  private static final int BATCHSIZE = 128;
  private static JFrame frame;
-  private static JFrame frame2;
  private static JPanel panel;
-  private static JPanel panel2;
  private static LayerConfiguration[] genLayers() {
    return new LayerConfiguration[] {
-       DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
-       ActivationLayer.builder(Activation.LEAKYRELU).build(),
-       DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
+       dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-       DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
+       dense().nIn(256).nOut(512).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-       DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
+       dense().nIn(512).nOut(1024).build(),
+       ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+       dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
    };
  }
@@ -119,65 +58,51 @@ public class App {
        .updater(UPDATER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
-       //.weightInit(WeightInit.XAVIER)
        .weightInit(WeightInit.XAVIER)
        .activation(Activation.IDENTITY)
        .layersFromArray(genLayers())
-       .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-       // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
+       .name("generator")
        .build();
+   ((NeuralNetConfiguration) conf).init();
    return conf;
  }
  private static LayerConfiguration[] disLayers() {
    return new LayerConfiguration[]{
-       DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
+       dense().nIn(784).nOut(1024).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
        DropoutLayer.builder(1 - 0.5).build(),
-       DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
+       dense().nIn(1024).nOut(512).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
        DropoutLayer.builder(1 - 0.5).build(),
-       DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
+       dense().nIn(512).nOut(256).build(),
        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
        DropoutLayer.builder(1 - 0.5).build(),
-       DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
-       ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-       DropoutLayer.builder(1 - 0.5).build(),
-       OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
+       OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
    };
  }
  private static NeuralNetConfiguration discriminator() {
-   NeuralNetConfiguration conf =
-       NeuralNetConfiguration.builder()
+   NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .seed(42)
        .updater(UPDATER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
        .weightInit(WeightInit.XAVIER)
-       //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-       .weightNoise(null)
-       // .weightInitFn(new WeightInitXavier())
-       // .activationFn(new ActivationIdentity())
        .activation(Activation.IDENTITY)
        .layersFromArray(disLayers())
-       .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
+       .name("discriminator")
        .build();
+   ((NeuralNetConfiguration) conf).init();
    return conf;
  }
  private static NeuralNetConfiguration gan() {
    LayerConfiguration[] genLayers = genLayers();
-   LayerConfiguration[] disLayers = Arrays.stream(disLayers())
+   LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
        .map((layer) -> {
          if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
-           return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
+           return FrozenLayerWithBackprop.builder(layer).build();
          } else {
            return layer;
          }
@@ -186,107 +111,57 @@ public class App {
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .seed(42)
-       .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
-       .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
-       .gradientNormalizationThreshold( 100 )
-       //.weightInitFn( new WeightInitXavier() ) //this is internal
-       .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-       .weightInit( WeightInit.XAVIER)
-       //.activationFn( new ActivationIdentity()) //this is internal
-       .activation( Activation.IDENTITY )
-       .layersFromArray( layers )
-       .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
+       .updater(UPDATER)
+       .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+       .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+       .weightInit(WeightInit.XAVIER)
+       .activation(Activation.IDENTITY)
+       .layersFromArray(layers)
+       .name("GAN")
        .build();
+   ((NeuralNetConfiguration) conf).init();
    return conf;
  }
-  @Test
+  @Tag("long-running")
+  @Test
  public void runTest() throws Exception {
-   main();
+   App.main(null);
  }
  public static void main(String... args) throws Exception {
-   log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
-   Nd4j.getMemoryManager().setAutoGcWindow(500);
-   // MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
-   // FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
-   FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
-   ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
-   ImageTransform transform2 = new ShowImageTransform("Tester", 30);
-   ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
-   ImageTransform tr = new PipelineImageTransform.Builder()
-       .addImageTransform(transform) //convert to GREY SCALE
-       .addImageTransform(transform3)
-       //.addImageTransform(transform2)
-       .build();
-   ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
-   imageRecordReader.initialize(fileSplit, tr);
-   DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
+   Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+   MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
    MultiLayerNetwork gen = new MultiLayerNetwork(generator());
    MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
    MultiLayerNetwork gan = new MultiLayerNetwork(gan());
-   gen.init(); log.debug("Generator network: {}", gen);
-   dis.init(); log.debug("Discriminator network: {}", dis);
-   gan.init(); log.debug("Complete GAN network: {}", gan);
+   gen.init();
+   dis.init();
+   gan.init();
    copyParams(gen, dis, gan);
-   gen.addTrainingListeners(new PerformanceListener(15, true));
-   //dis.addTrainingListeners(new PerformanceListener(10, true));
-   //gan.addTrainingListeners(new PerformanceListener(10, true));
-   //gan.addTrainingListeners(new ScoreToChartListener("gan"));
-   //dis.setListeners(new ScoreToChartListener("dis"));
-   System.out.println(gan.toString());
-   gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
-   //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
-   //trainData.reset();
+   gen.addTrainingListeners(new PerformanceListener(10, true));
+   dis.addTrainingListeners(new PerformanceListener(10, true));
+   gan.addTrainingListeners(new PerformanceListener(10, true));
+   trainData.reset();
    int j = 0;
-   for (int i = 0; i < 201; i++) { //epoch
+   for (int i = 0; i < 50; i++) {
      while (trainData.hasNext()) {
        j++;
-       DataSet next = trainData.next();
        // generate data
-       INDArray real = next.getFeatures();//.div(255f);
-       //start next round if there are not enough images left to have a full batchsize dataset
-       if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
-         log.warn("Your total number of input images is not a multiple of {}, "
-             + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
-         break;
-       }
-       if(i%20 == 0) {
-         // frame2 = visualize(new INDArray[]{real}, batchSize,
-         // frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
-       }
-       real.divi(255f);
-       // int batchSize = (int) real.shape()[0];
-       INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
+       INDArray real = trainData.next().getFeatures().muli(2).subi(1);
+       int batchSize = (int) real.shape()[0];
+       INDArray fakeIn = Nd4j.rand(batchSize, 100);
       INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
-       fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
-       //log.info("real has {} items.", real.length());
       DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
       DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
       DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
       dis.fit(data);
@@ -295,32 +170,26 @@ public class App {
       // Update the discriminator in the GAN network
       updateGan(gen, dis, gan);
-       //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
-       gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
+       gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
       if (j % 10 == 1) {
-         System.out.println("Iteration " + j + " Visualizing...");
-         INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
-         for (int k = 0; k < samples.length; k++) {
-           //INDArray input = fakeSet2.get(k).getFeatures();
+         System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
+         INDArray[] samples = new INDArray[9];
         DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
-           INDArray input = fakeSet2.get(k).getFeatures();
-           input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
+         for (int k = 0; k < 9; k++) {
+           INDArray input = fakeSet2.get(k).getFeatures();
           //samples[k] = gen.output(input, false);
           samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
-           samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
-           //samples[k] =
-           samples[k].addi(1f).divi(2f).muli(255f);
         }
-         frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
+         visualize(samples);
       }
     }
     trainData.reset();
-     // Copy the GANs generator to gen.
-     //updateGen(gen, gan);
   }
   // Copy the GANs generator to gen.
@@ -333,10 +202,8 @@ public class App {
   int genLayerCount = gen.getLayers().length;
   for (int i = 0; i < gan.getLayers().length; i++) {
     if (i < genLayerCount) {
-       if(gan.getLayer(i).getParams() != null)
       gen.getLayer(i).setParams(gan.getLayer(i).getParams());
     } else {
-       if(gan.getLayer(i).getParams() != null)
       dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
     }
   }
@@ -355,57 +222,41 @@ public class App {
   }
 }
- private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
-   if (isOrig) {
-     frame.setTitle("Viz Original");
-   } else {
-     frame.setTitle("Generated");
-   }
+ private static void visualize(INDArray[] samples) {
+   if (frame == null) {
+     frame = new JFrame();
+     frame.setTitle("Viz");
     frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
     frame.setLayout(new BorderLayout());
-   JPanel panelx = new JPanel();
-   panelx.setLayout(new GridLayout(4, 4, 8, 8));
-   for (INDArray sample : samples) {
-     for(int i = 0; i<batchElements; i++) {
-       panelx.add(getImage(sample, i, isOrig));
-     }
-   }
-   frame.add(panelx, BorderLayout.CENTER);
+     panel = new JPanel();
+     panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
+     frame.add(panel, BorderLayout.CENTER);
     frame.setVisible(true);
+   }
+   panel.removeAll();
+   for (INDArray sample : samples) {
+     panel.add(getImage(sample));
+   }
   frame.revalidate();
-   frame.setMinimumSize(new Dimension(300, 20));
   frame.pack();
-   return frame;
 }
- private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
-   final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
-   final int imageSize = X_DIM * Y_DIM;
-   final int offset = batchElement * imageSize;
-   int pxl = offset * CHANNELS; //where to start in the INDArray
-   //Image in NCHW - channels first format
-   for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
-     for (int y = 0; y < Y_DIM; y++) { // step through the columns x
-       for (int x = 0; x < X_DIM; x++) { //step through the rows y
-         if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
-         bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
-         pxl++; //next item in INDArray
-       }
-     }
-   }
+ private static JLabel getImage(INDArray tensor) {
+   BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+   for (int i = 0; i < 784; i++) {
+     int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
+     bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
+   }
   ImageIcon orig = new ImageIcon(bi);
-   Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
+   Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
   ImageIcon scaled = new ImageIcon(imageScaled);
   return new JLabel(scaled);
 }
}
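
Note on the rewritten loop above: each batch runs the classic GAN alternation. The discriminator is fit on merged real samples (label 0) and generator output (label 1); its fresh weights are then copied into the GAN's frozen discriminator layers; finally the generator is trained through the whole stack by feeding noise labelled 0 ("real"), so the error back-propagates through the frozen discriminator into the generator. A condensed sketch of one such step, using only calls that appear in the diff (the helper class and method names are illustrative):

import java.util.Arrays;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class GanStepSketch {
  // One adversarial update, assuming gen/dis/gan are wired as in App
  static void adversarialStep(MultiLayerNetwork gen, MultiLayerNetwork dis,
                              MultiLayerNetwork gan, INDArray real, int batchSize) {
    // forward noise through the generator half of the GAN only
    INDArray noise = Nd4j.rand(batchSize, 100);
    INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, noise);
    // discriminator sees real (label 0) and fake (label 1) together
    DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
    DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
    dis.fit(DataSet.merge(Arrays.asList(realSet, fakeSet)));
    // copy refreshed discriminator weights into the GAN's frozen copy (updateGan)
    int genLayerCount = gen.getLayers().length;
    for (int i = genLayerCount; i < gan.getLayers().length; i++) {
      gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
    }
    // generator update: noise labelled 0 ("real") trains gen through the frozen layers
    gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
  }
}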


@@ -0,0 +1,371 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import static net.brutex.gan.App2Config.BATCHSIZE;
@Slf4j
public class App2 {
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
static final int DIMENSIONS = 28;
static final int CHANNELS = 1;
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
final boolean BIAS = true;
private JFrame frame2, frame;
static final String OUTPUT_DIR = "d:/out/";
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
@Test @Tag("long-running")
void runTest() throws IOException {
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
ImageTransform tr = new PipelineImageTransform.Builder()
.addImageTransform(transform) //convert to GREY SCALE
.addImageTransform(transform3)
//.addImageTransform(transform2)
.build();
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
imageRecordReader.initialize(fileSplit, tr);
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
.map((layer) -> {
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
} else {
return layer;
}
}).toArray(LayerConfiguration[]::new);
NeuralNetConfiguration netConfiguration =
NeuralNetConfiguration.builder()
.name("GAN")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.updater(App2Config.UPDATER)
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
dis.init(); log.debug("Discriminator network: {}", dis);
gen.init(); log.debug("Generator network: {}", gen);
gan.init(); log.debug("GAN network: {}", gan);
log.info("Generator Summary:\n{}", gen.summary());
log.info("GAN Summary:\n{}", gan.summary());
dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
//gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
/*
Thread vt =
new Thread(
new Runnable() {
@Override
public void run() {
while (true) {
visualize(0, 0, gen);
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
});
vt.start();
*/
App2Display display = new App2Display();
//Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
DataSet data = null;
int j =0;
for(int i=0;i<App2Config.EPOCHS;i++) {
log.info("Epoch {}", i);
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
while (trainData.hasNext()) {
j++;
INDArray real = trainData.next().getFeatures();
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
Nd4j.rand(BATCHSIZE, App2Config.INPUT));
//generator output (TANH) is -1 to 1, rescale to 0..1
fake.addi(1f).divi(2f);
if (j % 50 == 1) {
display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
}
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(real.length()/BATCHSIZE!=784) break;
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
}
//fit the discriminator
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
//reset the training data and fit the complete GAN
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
log.info("Updated GAN's generator from gen.");
updateGen(gen, gan);
gen.save(new File("mnist-mlp-generator.dlj"));
}
//vt.stop();
/*
int j;
for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
j=0;
while (trainData.hasNext()) {
j++;
DataSet next = trainData.next();
// generate data
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
//start next round if there are not enough images left to have a full batchsize dataset
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
log.warn("Your total number of input images is not a multiple of {}, "
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
break;
}
//if(i%20 == 0) {
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
//}
//real.divi(255f);
// int batchSize = (int) real.shape()[0];
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
// when generator has TANH as activation - value range is -1 to 1
// when generator has SIGMOID, then range is 0 to 1
fake.addi(1f).divi(2f);
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
dis.fit(data);
dis.fit(data);
// Update the discriminator in the GAN network
updateGan(gen, dis, gan);
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
//Visualize and reporting
if (j % 10 == 1) {
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
for (int k = 0; k < samples.length; k++) {
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
INDArray input = fakeSet2.get(k).getFeatures();
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
input = input.reshape(1, App2Config.INPUT);
//samples[k] = gen.output(input, false);
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
//samples[k] =
//samples[k].muli(255f);
}
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
}
}
if (trainData.resetSupported()) {
trainData.reset();
} else {
log.error("Trainingdata {} does not support reset.", trainData.toString());
}
// Copy the GANs generator to gen.
updateGen(gen, gan);
log.info("Updated GAN's generator from gen.");
gen.save(new File("mnist-mlp-generator.dlj"));
}
*/
}
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
for (int i = 0; i < gen.getLayers().length; i++) {
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
}
}
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
int genLayerCount = gen.getLayers().length;
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
}
}
@Test
void testDiskriminator() throws IOException {
MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
net.init();
net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
DataSet data = null;
for(int i=0;i<App2Config.EPOCHS;i++) {
log.info("Epoch {}", i);
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
while (trainData.hasNext()) {
INDArray real = trainData.next().getFeatures();
long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
INDArray fake = Nd4j.rand(l );
DataSet realSet = new DataSet(real, label_real);
DataSet fakeSet = new DataSet(fake, label_fake);
if(real.length()/BATCHSIZE!=784) break;
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
}
net.fit(data);
trainData.reset();
}
long[] l = new long[]{BATCHSIZE, 784};
INDArray fake = Nd4j.rand(l );
DataSet fakeSet = new DataSet(fake, label_fake);
data = DataSet.merge(Arrays.asList(data, fakeSet));
ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
Evaluation eval = net.evaluate(iter);
log.info( "\n" + eval.confusionMatrix());
}
}
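
Note on testDiskriminator above: it trains the discriminator alone against random noise (label_fake = 0) versus MNIST digits (label_real = 1) and then inspects the confusion matrix; ExistingDataSetIterator adapts the merged in-memory DataSet to the DataSetIterator API that evaluate() expects. A minimal sketch of that evaluation step, mirroring the calls used in the test (helper names are illustrative):

import java.util.Arrays;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class DiscriminatorEvalSketch {
  // Evaluate a trained discriminator on an in-memory real/fake mix, as testDiskriminator does
  static Evaluation evalAgainstNoise(MultiLayerNetwork dis, DataSet realData, int batchSize) {
    INDArray fake = Nd4j.rand(batchSize, 784);                     // noise stands in for generator output
    DataSet fakeSet = new DataSet(fake, Nd4j.zeros(batchSize, 1)); // label_fake = 0, as in App2
    DataSet mixed = DataSet.merge(Arrays.asList(realData, fakeSet));
    return dis.evaluate(new ExistingDataSetIterator(mixed));       // wraps the DataSet for evaluate()
  }
}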


@@ -0,0 +1,183 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import static net.brutex.ai.dnn.api.NN.*;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class App2Config {
public static final int INPUT = 100;
public static final int BATCHSIZE=150;
public static final int X_DIM = 28;
public static final int Y_DIM = 28;
public static final int CHANNELS = 1;
public static final int EPOCHS = 50;
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
public static final boolean SHOW_GENERATED = true;
public static final float COLORSPACE = 255f;
final static int OUTPUT_PER_PANEL = 10;
static LayerConfiguration[] genLayerConfig() {
return new LayerConfiguration[] {
/*
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
ActivationLayer.builder().activation(Activation.RELU).build(), /*
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(1,1)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().nOut(CHANNELS).build(),
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
.kernelSize(2,2)
.stride(2,2)
.padding(0,0)
.convolutionMode(ConvolutionMode.Truncate)
.activation(Activation.RELU)
.hasBias(BIAS).build(),
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
// DropoutLayer.builder(0.001).build(),
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(256).nOut(512).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(512).nOut(1024).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
};
}
static LayerConfiguration[] disLayerConfig() {
return new LayerConfiguration[] {/*
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
.build(),
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
.build(),
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
BatchNormalization.builder().build(),
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SIGMOID)
.build()
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
ActivationLayer.builder().activation(Activation.RELU).build(),
DropoutLayer.builder(0.5).build(),
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
.activation(Activation.SIGMOID)
.build() */
dense().nIn(784).nOut(1024).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(1024).nOut(512).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
dense().nIn(512).nOut(256).hasBias(true).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DropoutLayer.builder(1 - 0.5).build(),
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
};
}
static NeuralNetConfiguration generator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("generator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER)
.weightInit(WeightInit.XAVIER)
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(App2Config.genLayerConfig())
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
.build();
conf.init();
return conf;
}
static NeuralNetConfiguration discriminator() {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.name("discriminator")
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.seed(42)
.updater(UPDATER_DIS)
.weightInit(WeightInit.XAVIER)
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
.weightNoise(null)
// .weightInitFn(new WeightInitXavier())
// .activationFn(new ActivationIdentity())
.activation(Activation.IDENTITY)
.layersFromArray(disLayerConfig())
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
//.dataType(DataType.FLOAT)
.build();
conf.init();
return conf;
}
}
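
The two factory methods above return already-initialized NeuralNetConfiguration objects, so wiring them into networks is one line each. A minimal smoke-test sketch, assuming App2Config and DL4J are on the classpath (the class name App2ConfigSmokeTest is illustrative, not part of the commit):

package net.brutex.gan;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class App2ConfigSmokeTest {
    public static void main(String[] args) {
        // build both networks from the configurations defined above
        MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
        gen.init();
        MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
        dis.init();
        // parameter counts are a quick sanity check that the layer chains wired up
        System.out.println("generator params: " + gen.numParams());
        System.out.println("discriminator params: " + dis.numParams());
    }
}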

View File

@@ -0,0 +1,160 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.gan;
import com.google.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.color.ColorSpace;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
import static net.brutex.gan.App2.OUTPUT_DIR;
import static net.brutex.gan.App2Config.*;
@Slf4j
@Singleton
public class App2Display {
private final JFrame frame = new JFrame();
private final App2GUI display = new App2GUI();
private final JPanel real_panel;
private final JPanel fake_panel;
public App2Display() {
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setContentPane(display.getOverall_panel());
frame.setMinimumSize(new Dimension(300, 20));
frame.pack();
frame.setVisible(true);
real_panel = display.getReal_panel();
fake_panel = display.getGen_panel();
real_panel.setLayout(new GridLayout(4, 4, 8, 8));
fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
}
public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
for (INDArray sample : samples) {
for(int i = 0; i<batchElements; i++) {
final Image img = this.getImage(sample, i, isOrig);
final ImageIcon icon = new ImageIcon(img);
if(isOrig) {
if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
real_panel.remove(0);
}
real_panel.add(new JLabel(icon));
} else {
if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
fake_panel.remove(0);
}
fake_panel.add(new JLabel(icon));
}
}
}
frame.pack();
frame.repaint();
}
public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
final BufferedImage bi;
if (CHANNELS > 1) {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); // multi-channel samples are rendered as RGB
} else {
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); // single-channel samples are rendered as grayscale
}
final int imageSize = X_DIM * Y_DIM;
final int offset = batchElement * imageSize;
int pxl = offset * CHANNELS; //where to start in the INDArray
//Image in NCHW - channels first format
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
for (int y = 0; y < X_DIM; y++) { // step through the rows (y)
for (int x = 0; x < Y_DIM; x++) { // step through the columns (x)
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
bi.getRaster().setSample(x, y, c, f_pxl);
pxl++; //next item in INDArray
}
}
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
//if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
return imageScaled;
}
private static void saveImage(Image image, int batchElement, boolean isOrig) {
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
try {
// Save the images to disk
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
log.debug("Images saved successfully.");
} catch (IOException e) {
log.error("Error saving the images: {}", e.getMessage());
}
}
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
File directory = new File(outputDirectory);
if (!directory.exists()) {
directory.mkdir();
}
File outputFile = new File(directory, fileName);
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
}
public static BufferedImage imageToBufferedImage(Image image) {
if (image instanceof BufferedImage) {
return (BufferedImage) image;
}
// Create a buffered image with the same dimensions and transparency as the original image
BufferedImage bufferedImage;
if (CHANNELS > 1) {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
} else {
bufferedImage =
new BufferedImage(
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
}
// Draw the original image onto the buffered image
Graphics2D g2d = bufferedImage.createGraphics();
g2d.drawImage(image, 0, 0, null);
g2d.dispose();
return bufferedImage;
}
}
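
For reference, a hypothetical caller of the class above. The batch size of 16, the 784-element sample width (28x28 grayscale, matching the generator's nOut(784) in App2Config), and the use of Nd4j.rand as a stand-in for real generator output are all illustrative:

package net.brutex.gan;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class App2DisplayDemo {
    public static void main(String[] args) {
        App2Display display = new App2Display();
        INDArray fakeBatch = Nd4j.rand(DataType.FLOAT, 16, 784); // stand-in for generator output
        display.visualize(new INDArray[] {fakeBatch}, 4, false); // render 4 samples into the "fake" panel
    }
}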

View File

@@ -0,0 +1,61 @@
package net.brutex.gan;
import javax.swing.JPanel;
import javax.swing.JSplitPane;
import javax.swing.JLabel;
import java.awt.BorderLayout;
public class App2GUI extends JPanel {
/**
*
*/
private static final long serialVersionUID = 1L;
private JPanel overall_panel;
private JPanel real_panel;
private JPanel gen_panel;
/**
* Create the panel.
*/
public App2GUI() {
overall_panel = new JPanel();
add(overall_panel);
JSplitPane splitPane = new JSplitPane();
overall_panel.add(splitPane);
JPanel p1 = new JPanel();
splitPane.setLeftComponent(p1);
p1.setLayout(new BorderLayout(0, 0));
JLabel lblNewLabel = new JLabel("Generator");
p1.add(lblNewLabel, BorderLayout.NORTH);
gen_panel = new JPanel();
p1.add(gen_panel, BorderLayout.SOUTH);
JPanel p2 = new JPanel();
splitPane.setRightComponent(p2);
p2.setLayout(new BorderLayout(0, 0));
JLabel lblNewLabel_1 = new JLabel("Real");
p2.add(lblNewLabel_1, BorderLayout.NORTH);
real_panel = new JPanel();
p2.add(real_panel, BorderLayout.SOUTH);
}
public JPanel getOverall_panel() {
return overall_panel;
}
public JPanel getReal_panel() {
return real_panel;
}
public JPanel getGen_panel() {
return gen_panel;
}
}
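
App2Display above already shows the intended embedding of this panel; as a standalone sketch (frame title illustrative, javax.swing.JFrame import assumed):

JFrame frame = new JFrame("App2GUI preview");
App2GUI gui = new App2GUI();
frame.setContentPane(gui.getOverall_panel()); // the split pane with "Generator" and "Real" panels
frame.pack();
frame.setVisible(true);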

View File

@@ -24,12 +24,15 @@ package net.brutex.gan;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -98,7 +101,10 @@ public class MnistSimpleGAN {
         return new MultiLayerNetwork(discConf);
     }
 
+    @Test @Tag("long-running")
+    public void runTest() throws Exception {
+        main(null);
+    }
+
     public static void main(String[] args) throws Exception {
         GAN gan = new GAN.Builder()
                 .generator(MnistSimpleGAN::getGenerator)
@@ -108,6 +114,7 @@ public class MnistSimpleGAN {
                 .updater(UPDATER)
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .gradientNormalizationThreshold(100)
                 .build();
 
         Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

View File

@@ -25,7 +25,7 @@
 # Default logging detail level for all instances of SimpleLogger.
 # Must be one of ("trace", "debug", "info", "warn", or "error").
 # If not specified, defaults to "info".
-org.slf4j.simpleLogger.defaultLogLevel=trace
+org.slf4j.simpleLogger.defaultLogLevel=debug
 
 # Logging detail level for a SimpleLogger instance named "xxxxx".
 # Must be one of ("trace", "debug", "info", "warn", or "error").
@@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
 # If the format is not specified or is invalid, the default format is used.
 # The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 #org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
-org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
+#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
 
 # Set to true if you want to output the current thread name.
 # Defaults to true.
-org.slf4j.simpleLogger.showThreadName=true
+#org.slf4j.simpleLogger.showThreadName=true

View File

@@ -88,7 +88,7 @@ public class CNN1DTestCases {
                 .convolutionMode(ConvolutionMode.Same))
                 .graphBuilder()
                 .addInputs("in")
-                .layer("0", Convolution1DLayer.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
+                .layer("0", Convolution1D.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
                 .layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
                 .layer("2", Cropping1D.builder(1).build(), "1")
                 .layer("3", ZeroPadding1DLayer.builder(1).build(), "2")

View File

@@ -20,7 +20,7 @@ ext {
     def javacv = [version:"1.5.7"]
     def opencv = [version: "4.5.5"]
-    def leptonica = [version: "1.82.0"]
+    def leptonica = [version: "1.83.0"] //fix, only in javacpp 1.5.9
     def junit = [version: "5.9.1"]
     def flatbuffers = [version: "1.10.0"]
@@ -71,7 +71,7 @@ dependencies {
     // api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
-    api "org.projectlombok:lombok:1.18.26"
+    api "org.projectlombok:lombok:1.18.28"
 
     /*Logging*/
     api 'org.slf4j:slf4j-api:2.0.3'
@@ -118,7 +118,8 @@ dependencies {
     api "org.bytedeco:javacv:${javacv.version}"
     api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
-    api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
+    api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
+    api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
     api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
@@ -129,6 +130,7 @@ dependencies {
     api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
+    api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
     api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
     api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
     api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"

View File

@@ -28,7 +28,8 @@ dependencies {
     implementation "org.bytedeco:javacv"
     implementation "org.bytedeco:opencv"
     implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
-    implementation "org.bytedeco:leptonica-platform"
+    //implementation "org.bytedeco:leptonica-platform"
+    implementation group: "org.bytedeco", name: "leptonica", classifier: buildTarget
     implementation "org.bytedeco:hdf5-platform"
     implementation "commons-io:commons-io"

View File

@@ -46,7 +46,7 @@ import java.nio.ByteOrder;
 import org.bytedeco.leptonica.*;
 import org.bytedeco.opencv.opencv_core.*;
-import static org.bytedeco.leptonica.global.lept.*;
+import static org.bytedeco.leptonica.global.leptonica.*;
 import static org.bytedeco.opencv.global.opencv_core.*;
 import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
 import static org.bytedeco.opencv.global.opencv_imgproc.*;

View File

@@ -52,10 +52,9 @@ import java.io.InputStream;
 import java.lang.reflect.Field;
 import java.nio.file.Path;
 import java.util.Random;
-import java.util.stream.IntStream;
 import java.util.stream.Stream;
-import static org.bytedeco.leptonica.global.lept.*;
+import static org.bytedeco.leptonica.global.leptonica.*;
 import static org.bytedeco.opencv.global.opencv_core.*;
 import static org.junit.jupiter.api.Assertions.*;

View File

@@ -2386,7 +2386,11 @@ public interface INDArray extends Serializable, AutoCloseable {
     long[] stride();
 
     /**
-     * Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray
+     * Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray <br/><br/>
+     * C Is Contiguous layout. Mathematically speaking, row major.<br/>
+     * F Is Fortran contiguous layout. Mathematically speaking, column major.<br/>
+     * {@see https://en.wikipedia.org/wiki/Row-_and_column-major_order}<br/>
+     *
      * @return the ordering of this ndarray
      */
    char ordering();
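
The expanded javadoc is easy to verify in a few lines. A sketch using the standard Nd4j factory overload that takes an explicit ordering character (assumes the usual org.nd4j imports, run inside any main):

// same data, two memory layouts
INDArray c = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new int[]{2, 3}, 'c'); // row major
INDArray f = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new int[]{2, 3}, 'f'); // column major
System.out.println(c.ordering()); // 'c'
System.out.println(f.ordering()); // 'f'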

View File

@@ -334,6 +334,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
     public void save(File to) {
         try (FileOutputStream fos = new FileOutputStream(to, false);
              BufferedOutputStream bos = new BufferedOutputStream(fos)) {
+            to.mkdirs();
             save(bos);
         } catch (IOException e) {
             throw new RuntimeException(e);
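
Note that on java.io.File, mkdirs() creates the given path itself as a directory, so the usual idiom for preparing a save location targets the parent directory instead. A sketch of that idiom (the path is illustrative):

File to = new File("out/datasets/batch-0.bin");
File parent = to.getParentFile();
if (parent != null && !parent.exists()) {
    parent.mkdirs(); // creates out/datasets, not a directory named batch-0.bin
}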

View File

@@ -5121,7 +5121,7 @@ public class Nd4j {
             Nd4j.backend = backend;
             updateNd4jContext();
             props = Nd4jContext.getInstance().getConf();
-            logger.info("Properties for Nd4jContext " + props);
+            log.debug("Properties for Nd4jContext {}", props);
             PropertyParser pp = new PropertyParser(props);
             String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
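
The switch to a parameterized message is the standard SLF4J pattern: string concatenation builds the message even when the level is disabled, while the {} placeholder defers formatting until the level check passes. Schematically:

log.info("Properties for Nd4jContext " + props);   // props.toString() runs unconditionally
log.debug("Properties for Nd4jContext {}", props); // formatted only if debug is enabled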

View File

@@ -166,10 +166,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         int seed = 123;
         int listenerFreq = 1;
 
-        LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
+        final LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
                 new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
 
-        NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
+        final var builder = NeuralNetConfiguration.builder().seed(seed)
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                 .layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
@@ -181,7 +181,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
                         .build())
                 .inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
 
-        MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
+        final MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
         model.init();
         model.addTrainingListeners(new ScoreIterationListener(listenerFreq));

View File

@@ -45,6 +45,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.api.BaseTrainingListener;
@@ -924,8 +925,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
         };
 
         for(EpochTerminationCondition e : etc ){
-            String s = NeuralNetConfiguration.mapper().writeValueAsString(e);
-            EpochTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, EpochTerminationCondition.class);
+            String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(e);
+            EpochTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, EpochTerminationCondition.class);
             assertEquals(e, c);
         }
@@ -936,8 +937,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
         };
 
         for(IterationTerminationCondition i : itc ){
-            String s = NeuralNetConfiguration.mapper().writeValueAsString(i);
-            IterationTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, IterationTerminationCondition.class);
+            String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(i);
+            IterationTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, IterationTerminationCondition.class);
             assertEquals(i, c);
         }
     }
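
The same round-trip works outside the test. A sketch using only the CavisMapper API visible in this diff plus DL4J's stock MaxEpochsTerminationCondition (exception handling elided; the variable names are arbitrary):

ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
EpochTerminationCondition cond = new MaxEpochsTerminationCondition(20);
String json = m.writeValueAsString(cond);                                // serialize to JSON
EpochTerminationCondition back = m.readValue(json, EpochTerminationCondition.class); // and back
assert cond.equals(back); // the conditions implement value equality, as the test relies on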

View File

@@ -309,7 +309,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
         try {
             NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
-                    .list()
                     .layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
                             .build())
                     .layer(1, OutputLayer.builder().nOut(10).build())

View File

@@ -77,7 +77,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                 NeuralNetConfiguration.builder().updater(new NoOp())
                         .dataType(DataType.DOUBLE)
                         .seed(12345L)
-                        .dist(new NormalDistribution(0, 1)).list()
+                        .weightInit(new NormalDistribution(0, 1))
                         .layer(0, DenseLayer.builder().nIn(4).nOut(3)
                                 .activation(Activation.IDENTITY).build())
                         .layer(1, BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
@@ -122,7 +122,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                         .dataType(DataType.DOUBLE)
                         .updater(new NoOp()).seed(12345L)
                         .dist(new NormalDistribution(0, 2)).list()
-                        .layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
+                        .layer(0, Convolution2D.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
                                 .activation(Activation.IDENTITY).build())
                         .layer(1, BatchNormalization.builder().useLogStd(useLogStd).build())
                         .layer(2, ActivationLayer.builder().activation(Activation.TANH).build())

View File

@@ -91,9 +91,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .updater(new NoOp())
                 .dist(new NormalDistribution(0, 1))
                 .convolutionMode(ConvolutionMode.Same)
-                .list()
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -202,7 +201,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .dist(new NormalDistribution(0, 1))
                 .convolutionMode(ConvolutionMode.Same)
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -211,7 +210,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .build())
                 .layer(Cropping1D.builder(cropping).build())
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -317,7 +316,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .dist(new NormalDistribution(0, 1))
                 .convolutionMode(ConvolutionMode.Same)
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -326,7 +325,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .build())
                 .layer(ZeroPadding1DLayer.builder(zeroPadding).build())
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -435,10 +434,9 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .updater(new NoOp())
                 .dist(new NormalDistribution(0, 1))
                 .convolutionMode(ConvolutionMode.Same)
-                .list()
                 .layer(
                     0,
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -447,7 +445,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .build())
                 .layer(
                     1,
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .activation(afn)
                         .kernelSize(kernel)
                         .stride(stride)
@@ -461,6 +459,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .stride(stride)
                         .padding(padding)
                         .pnorm(pnorm)
+                        .name("SubsamplingLayer")
                         .build())
                 .layer(
                     3,
@@ -548,7 +547,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .seed(12345)
                 .list()
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .kernelSize(2)
                         .rnnDataFormat(RNNFormat.NCW)
                         .stride(stride)
@@ -562,7 +561,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .pnorm(pnorm)
                         .build())
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .kernelSize(2)
                         .rnnDataFormat(RNNFormat.NCW)
                         .stride(stride)
@@ -655,7 +654,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                 .seed(12345)
                 .list()
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .kernelSize(k)
                         .dilation(d)
                         .hasBias(hasBias)
@@ -664,7 +663,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                         .nOut(convNOut1)
                         .build())
                 .layer(
-                    Convolution1DLayer.builder()
+                    Convolution1D.builder()
                         .kernelSize(k)
                         .dilation(d)
                         .convolutionMode(ConvolutionMode.Causal)

View File

@@ -0,0 +1,811 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.Convolution1DUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Slf4j
public class CNN1DNewGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
private static final double DEFAULT_EPS = 1e-6;
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
static {
Nd4j.setDataType(DataType.DOUBLE);
}
@Test
public void testCnn1D() {
int minibatchSize = 4;
int[] dataChannels = {4, 10}; // input channel counts to test
int[] kernels = {2,4,5,8};
int stride = 2;
int padding = 3;
int seriesLength = 300;
for (int kernel : kernels) {
for (int dChannels : dataChannels) {
int numLabels = ((seriesLength + (2 * padding) - kernel) / stride) + 1;
final NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1))
.convolutionMode(ConvolutionMode.Same)
.layer(
Convolution1DNew.builder()
.activation(Activation.RELU)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nIn(dChannels) // channels
.nOut(3)
.rnnDataFormat(RNNFormat.NCW)
.build())
.layer(
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(4)
.build())
.inputType(InputType.recurrent(dChannels, seriesLength))
.build();
INDArray input = Nd4j.rand(minibatchSize, dChannels, seriesLength);
INDArray labels = Nd4j.zeros(minibatchSize, 4, numLabels);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < numLabels; j++) {
labels.putScalar(new int[] {i, i % 4, j}, 1.0);
}
}
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg =
"Minibatch="
+ minibatchSize
+ ", activationFn="
+ Activation.RELU
+ ", kernel = "
+ kernel;
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
/**
List<Pair<INDArray, INDArray>> iter = new java.util.ArrayList<>(Collections.emptyList());
iter.add(new Pair<>(input, labels));
for(int x=0;x<100; x++) net.fit(input, labels);
Evaluation eval = net.evaluate(new INDArrayDataSetIterator(iter,2), Arrays.asList(new String[]{"One", "Two", "Three", "Four"}));
// net.fit(input, labels);
eval.eval(labels, net.output(input));
**/
boolean gradOK =
GradientCheckUtil.checkGradients(
net,
DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR,
PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE,
input,
labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
@Test
public void testCnn1DWithLocallyConnected1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {2, 3};
int length = 25;
int convNIn = 18;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1,2,4};
int stride = 1;
int padding = 0;
Activation[] activations = {Activation.SIGMOID};
for (Activation afn : activations) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < length; j++) {
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
}
}
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1))
.convolutionMode(ConvolutionMode.Same)
.layer(
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nIn(convNIn)
.nOut(convNOut1)
.rnnDataFormat(RNNFormat.NCW)
.build())
.layer(
LocallyConnected1D.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nIn(convNOut1)
.nOut(convNOut2)
.hasBias(false)
.build())
.layer(
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length))
.build();
String json = conf.toJson();
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg =
"Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("ILayer " + j + " # params: " +
// net.getLayer(j).numParams());
}
boolean gradOK =
GradientCheckUtil.checkGradients(
net,
DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR,
PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE,
input,
labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
public void testCnn1DWithCropping1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {1, 3};
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int stride = 1;
int padding = 0;
int cropping = 1;
int croppedLength = length - 2 * cropping;
Activation[] activations = {Activation.SIGMOID};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {
SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG,
SubsamplingLayer.PoolingType.PNORM
};
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < croppedLength; j++) {
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
}
}
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1))
.convolutionMode(ConvolutionMode.Same)
.layer(
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut1)
.build())
.layer(Cropping1D.builder(cropping).build())
.layer(
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut2)
.build())
.layer(
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
.build();
String json = conf.toJson();
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg =
"PoolingType="
+ poolingType
+ ", minibatch="
+ minibatchSize
+ ", activationFn="
+ afn
+ ", kernel = "
+ kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("ILayer " + j + " # params: " +
// net.getLayer(j).numParams());
}
boolean gradOK =
GradientCheckUtil.checkGradients(
net,
DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR,
PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE,
input,
labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
}
@Test
public void testCnn1DWithZeroPadding1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {1, 3};
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int stride = 1;
int pnorm = 2;
int padding = 0;
int zeroPadding = 2;
int paddedLength = length + 2 * zeroPadding;
Activation[] activations = {Activation.SIGMOID};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {
SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG,
SubsamplingLayer.PoolingType.PNORM
};
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < paddedLength; j++) {
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
}
}
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1))
.convolutionMode(ConvolutionMode.Same)
.layer(
Convolution1DNew.builder()
.activation(afn)
.kernelSize(2, kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut1)
.build())
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
.layer(
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut2)
.build())
.layer(ZeroPadding1DLayer.builder(0).build())
.layer(
Subsampling1DLayer.builder(poolingType)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.pnorm(pnorm)
.build())
.layer(
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
.build();
String json = conf.toJson();
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg =
"PoolingType="
+ poolingType
+ ", minibatch="
+ minibatchSize
+ ", activationFn="
+ afn
+ ", kernel = "
+ kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("ILayer " + j + " # params: " +
// net.getLayer(j).numParams());
}
boolean gradOK =
GradientCheckUtil.checkGradients(
net,
DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR,
PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE,
input,
labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
}
@Test
public void testCnn1DWithSubsampling1D() {
Nd4j.getRandom().setSeed(12345);
int[] minibatchSizes = {1, 3};
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int stride = 1;
int padding = 0;
int pnorm = 2;
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {
SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG,
SubsamplingLayer.PoolingType.PNORM
};
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < length; j++) {
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
}
}
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1))
.convolutionMode(ConvolutionMode.Same)
.layer(
0,
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut1)
.build())
.layer(
1,
Convolution1DNew.builder()
.activation(afn)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.nOut(convNOut2)
.build())
.layer(
2,
Subsampling1DLayer.builder(poolingType)
.kernelSize(kernel)
.stride(stride)
.padding(padding)
.pnorm(pnorm)
.name("SubsamplingLayer")
.build())
.layer(
3,
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
.build();
String json = conf.toJson();
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg =
"PoolingType="
+ poolingType
+ ", minibatch="
+ minibatchSize
+ ", activationFn="
+ afn
+ ", kernel = "
+ kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("ILayer " + j + " # params: " +
// net.getLayer(j).numParams());
}
boolean gradOK =
GradientCheckUtil.checkGradients(
net,
DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR,
PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE,
input,
labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
}
@Test
public void testCnn1dWithMasking() {
int length = 12;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int pnorm = 2;
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {
SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG
};
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (ConvolutionMode cm :
new ConvolutionMode[] {ConvolutionMode.Same, ConvolutionMode.Truncate}) {
for (int stride : new int[] {1, 2}) {
String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 1))
.convolutionMode(cm)
.seed(12345)
.layer(
Convolution1DNew.builder()
.kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride)
.nIn(convNIn)
.nOut(convNOut1)
.build())
.layer(
Subsampling1DLayer.builder(poolingType)
.kernelSize(2)
.stride(stride)
.pnorm(pnorm)
.build())
.layer(
Convolution1DNew.builder()
.kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride)
.nIn(convNOut1)
.nOut(convNOut2)
.build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
.layer(
OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(2, convNIn, length);
INDArray fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
INDArray label = TestUtils.randomOneHot(2, finalNOut);
boolean gradOK =
GradientCheckUtil.checkGradients(
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
assertTrue(gradOK, s);
TestUtils.testModelSerialization(net);
// TODO also check that masked step values don't impact forward pass, score or gradients
DataSet ds = new DataSet(f, label, fm, null);
double scoreBefore = net.score(ds);
net.setInput(f);
net.setLabels(label);
net.setLayerMaskArrays(fm, null);
net.computeGradientAndScore();
INDArray gradBefore = net.getFlattenedGradients().dup();
f.putScalar(1, 0, 10, 10.0);
f.putScalar(1, 1, 11, 20.0);
double scoreAfter = net.score(ds);
net.setInput(f);
net.setLabels(label);
net.setLayerMaskArrays(fm, null);
net.computeGradientAndScore();
INDArray gradAfter = net.getFlattenedGradients().dup();
assertEquals(scoreBefore, scoreAfter, 1e-6);
assertEquals(gradBefore, gradAfter);
}
}
}
}
@Test
public void testCnn1Causal() throws Exception {
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int[] lengths = {11, 12, 13, 9, 10, 11};
int[] kernels = {2, 3, 2, 4, 2, 3};
int[] dilations = {1, 1, 2, 1, 2, 1};
int[] strides = {1, 2, 1, 2, 1, 1};
boolean[] masks = {false, true, false, true, false, true};
boolean[] hasB = {true, false, true, false, true, true};
for (int i = 0; i < lengths.length; i++) {
int length = lengths[i];
int k = kernels[i];
int d = dilations[i];
int st = strides[i];
boolean mask = masks[i];
boolean hasBias = hasB[i];
// TODO has bias
String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.weightInit(new NormalDistribution(0, 1))
.seed(12345)
.layer(
Convolution1DNew.builder()
.kernelSize(k)
.dilation(d)
.hasBias(hasBias)
.convolutionMode(ConvolutionMode.Causal)
.stride(st)
.nOut(convNOut1)
.build())
.layer(
Convolution1DNew.builder()
.kernelSize(k)
.dilation(d)
.convolutionMode(ConvolutionMode.Causal)
.stride(st)
.nOut(convNOut2)
.build())
.layer(
RnnOutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(finalNOut)
.build())
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
INDArray fm = null;
if (mask) {
fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
}
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
long outSize2 =
Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
String msg =
"Minibatch="
+ 1
+ ", activationFn="
+ Activation.RELU
+ ", kernel = "
+ k;
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
boolean gradOK =
GradientCheckUtil.checkGradients(
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
assertTrue(gradOK, s);
TestUtils.testModelSerialization(net);
}
}
}
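
The causal output lengths computed in the last test follow out = ceil(in / stride). A quick check with the same Convolution1DUtils helper the test calls (values illustrative):

long out1 = Convolution1DUtils.getOutputSize(12, 3, 2, 0, ConvolutionMode.Causal, 1);
long out2 = Convolution1DUtils.getOutputSize(out1, 3, 2, 0, ConvolutionMode.Causal, 1);
System.out.println(out1 + " -> " + out2); // 6 -> 3 under the ceil(in/stride) rule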

View File

@@ -112,9 +112,8 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
             NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                     .dataType(DataType.DOUBLE)
-                    .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
+                    .updater(new NoOp())
                     .dist(new NormalDistribution(0, 1))
-                    .list()
                     .layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
                             .stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
                             .convolutionMode(mode).dataFormat(df)
@@ -400,7 +399,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                     .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
                     .dist(new NormalDistribution(0, 1))
                     .seed(12345)
-                    .list()
                     .layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
                             .nIn(convNIn).nOut(convNOut).hasBias(false)
                             .convolutionMode(mode).dataFormat(df)

View File

@@ -108,8 +108,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                     .updater(new NoOp())
                     .weightInit(WeightInit.XAVIER)
                     .seed(12345L)
-                    .list()
-                    .layer(0, ConvolutionLayer.builder(1, 1).nOut(6).activation(afn).build())
+                    .layer(0, Convolution2D.builder().kernelSize(1).stride(1).nOut(6).activation(afn).build())
                     .layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
                     .inputType(InputType.convolutionalFlat(1, 4, 1));

View File

@@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -336,7 +337,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
             // to ensure that we carry the parameters through
             // the serializer.
             try{
-                ObjectMapper m = NeuralNetConfiguration.mapper();
+                ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
                 String s = m.writeValueAsString(lossFunctions[i]);
                 ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
                 lossFunctions[i] = lf2;

View File

@@ -180,7 +180,7 @@ public class DTypeTests extends BaseDL4JTest {
             Pooling2D.class, //Alias for SubsamplingLayer
             Convolution2D.class, //Alias for ConvolutionLayer
             Pooling1D.class, //Alias for Subsampling1D
-            Convolution1D.class, //Alias for Convolution1DLayer
+            Convolution1D.class, //Alias for Convolution1D
             TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
     ));

View File

@@ -37,7 +37,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import org.deeplearning4j.util.ConvolutionUtils;
+import org.deeplearning4j.util.Convolution2DUtils;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
 import org.nd4j.linalg.activations.Activation;
@@ -1026,7 +1026,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
             } catch (DL4JInvalidInputException e) {
                 // e.printStackTrace();
                 String msg = e.getMessage();
-                assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
+                assertTrue(msg.contains(Convolution2DUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
             }
         }
     }

View File

@@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
@@ -921,7 +921,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
         NeuralNetConfiguration.builder()
                 .convolutionMode(ConvolutionMode.Same)
                 .layer(
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                                 .nOut(3)
                                 .kernelSize(2)
                                 .activation(Activation.TANH)
@@ -975,7 +975,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
     @Test
     public void testConv1dCausalAllowed() {
-        Convolution1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
+        Convolution1D.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
         Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
     }

View File

@@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.util.ConvolutionUtils;
+import org.deeplearning4j.util.Convolution2DUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -346,7 +346,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
    assertEquals(2, it.getHeight());
    assertEquals(2, it.getWidth());
    assertEquals(dOut, it.getChannels());
-   int[] outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
+   int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
    assertEquals(2, outSize[0]);
    assertEquals(2, outSize[1]);
@@ -357,7 +357,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
    assertEquals(2, it.getHeight());
    assertEquals(2, it.getWidth());
    assertEquals(dOut, it.getChannels());
-   outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
+   outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
    assertEquals(2, outSize[0]);
    assertEquals(2, outSize[1]);
@@ -367,7 +367,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
    assertEquals(3, it.getHeight());
    assertEquals(3, it.getWidth());
    assertEquals(dOut, it.getChannels());
-   outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
+   outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
    assertEquals(3, outSize[0]);
    assertEquals(3, outSize[1]);
@@ -397,7 +397,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
        System.out.println(e.getMessage());
    }
    try {
-       outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
+       outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
        fail("Exception expected");
    } catch (DL4JException e) {
        System.out.println(e.getMessage());
@@ -409,7 +409,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
    assertEquals(1, it.getHeight());
    assertEquals(1, it.getWidth());
    assertEquals(dOut, it.getChannels());
-   outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
+   outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
    assertEquals(1, outSize[0]);
    assertEquals(1, outSize[1]);
@@ -419,7 +419,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
    assertEquals(2, it.getHeight());
    assertEquals(2, it.getWidth());
    assertEquals(dOut, it.getChannels());
-   outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
+   outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
    assertEquals(2, outSize[0]);
    assertEquals(2, outSize[1]);
}
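
ConvolutionUtils is renamed to Convolution2DUtils; the getOutputSize overload these tests use is unchanged. A small sketch of the arithmetic it checks, using only the class and method names visible in the hunks above (Truncate mode computes floor((in + 2*pad - kernel) / stride) + 1 per spatial dimension):

import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.util.Convolution2DUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OutputSizeSketch {
    public static void main(String[] args) {
        INDArray inData = Nd4j.create(1, 1, 3, 3); // NCHW: minibatch, channels, height, width
        int[] kernel = {2, 2};
        int[] stride = {1, 1};
        int[] padding = {0, 0};
        // (3 + 2*0 - 2) / 1 + 1 = 2 in each spatial dimension
        int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
        System.out.println(outSize[0] + "x" + outSize[1]); // 2x2
    }
}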

View File

@@ -732,7 +732,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
    .weightInit(WeightInit.XAVIER)
    .convolutionMode(ConvolutionMode.Same)
    .layer(rnn ? LSTM.builder().nOut(3).build() :
-           Convolution1DLayer.builder().kernelSize(3).stride(1).nOut(3).build())
+           Convolution1D.builder().kernelSize(3).stride(1).nOut(3).build())
    .layer(BatchNormalization.builder().build())
    .layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
    .inputType(InputType.recurrent(3))

View File

@@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
    .graphBuilder()
    .addInputs(inputName)
    .setOutputs(output)
-   .layer(conv, Convolution1DLayer.builder(7)
+   .layer(conv, Convolution1D.builder(7)
        .convolutionMode(ConvolutionMode.Same)
        .nOut(input.size(1))
        .weightInit(new WeightInitIdentity())

View File

@@ -23,6 +23,7 @@ package org.deeplearning4j.regressiontest;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.*;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -38,7 +39,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
    new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
        new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
-   ObjectMapper om = NeuralNetConfiguration.mapper();
+   ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
    for (Distribution d : distributions) {
        String json = om.writeValueAsString(d);
@@ -50,7 +51,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
    @Test
    public void testDistributionDeserializerLegacyFormat() throws Exception {
-       ObjectMapper om = NeuralNetConfiguration.mapper();
+       ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
        String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
            + " \"std\" : 1.2\n" + " }\n" + " }";
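
NeuralNetConfiguration.mapper() is replaced by the new CavisMapper.getMapper(CavisMapper.Type.JSON) factory wherever a Jackson ObjectMapper is needed. A minimal round-trip sketch using only calls visible in this diff (readValue is standard Jackson; Distribution equality on round-trip is what these tests assert):

import com.fasterxml.jackson.databind.ObjectMapper;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.serde.CavisMapper;

public class CavisMapperSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
        Distribution d = new NormalDistribution(0.1, 1.2);
        String json = om.writeValueAsString(d);                    // serialize
        Distribution d2 = om.readValue(json, Distribution.class);  // deserialize
        System.out.println(json + " round-trips: " + d.equals(d2));
    }
}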

View File

@@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
-import org.deeplearning4j.util.ConvolutionUtils;
+import org.deeplearning4j.util.Convolution2DUtils;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
@@ -681,9 +681,9 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
    int[] outSize;
    if (convolutionMode == ConvolutionMode.Same) {
-       outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
-       padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
-       int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
+       outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
+       padding = Convolution2DUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
+       int[] padBottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
        if(!Arrays.equals(padding, padBottomRight)){
            /*
            CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
@@ -731,7 +731,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
        // CuDNN handle
        }
    } else {
-       outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
+       outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
    }
    return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
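
Only the utility class name changes here; the Same-mode logic stays as it was: compute the output size, then the top-left and bottom-right padding, and fall back to manually padding the input when the two differ, because cuDNN cannot express asymmetric padding. A sketch of that check with hand-picked sizes (the hardcoded outSize stands in for the getOutputSize call above; the int[] {height, width} convention follows the diff):

import java.util.Arrays;
import org.deeplearning4j.util.Convolution2DUtils;

public class SameModePaddingSketch {
    public static void main(String[] args) {
        int[] inSize = {5, 5};     // input height, width
        int[] outSize = {5, 5};    // Same mode keeps spatial size at stride 1
        int[] kernel = {2, 2};     // even kernel -> asymmetric padding
        int[] strides = {1, 1};
        int[] dilation = {1, 1};
        int[] topLeft = Convolution2DUtils.getSameModeTopLeftPadding(outSize, inSize, kernel, strides, dilation);
        int[] bottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, inSize, kernel, strides, dilation);
        // When these differ, CudnnConvolutionHelper pads the input manually instead.
        System.out.println(Arrays.equals(topLeft, bottomRight));
    }
}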

View File

@@ -42,7 +42,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
-import org.deeplearning4j.util.ConvolutionUtils;
+import org.deeplearning4j.util.Convolution2DUtils;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.learning.config.IUpdater;
@@ -442,8 +442,8 @@ public class KerasModel {
    KerasInput kerasInput = (KerasInput) layer;
    LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
    //no dim order, try to pull it from the next layer if there is one
-   if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
-       CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
+   if(Convolution2DUtils.layerHasConvolutionLayout(layer1)) {
+       CNN2DFormat formatForLayer = Convolution2DUtils.getFormatForLayer(layer1);
        if(formatForLayer == CNN2DFormat.NCHW) {
            dimOrder = KerasLayer.DimOrder.THEANO;
        } else if(formatForLayer == CNN2DFormat.NHWC) {

View File

@@ -52,28 +52,44 @@ public class KerasSequentialModel extends KerasModel {
     * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
     */
    public KerasSequentialModel(KerasModelBuilder modelBuilder)
-       throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
-       this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
-           modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
-           modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
+       throws UnsupportedKerasConfigurationException,
+           IOException,
+           InvalidKerasConfigurationException {
+       this(
+           modelBuilder.getModelJson(),
+           modelBuilder.getModelYaml(),
+           modelBuilder.getWeightsArchive(),
+           modelBuilder.getWeightsRoot(),
+           modelBuilder.getTrainingJson(),
+           modelBuilder.getTrainingArchive(),
+           modelBuilder.isEnforceTrainingConfig(),
+           modelBuilder.getInputShape());
    }
    /**
-    * (Not recommended) Constructor for Sequential model from model configuration
-    * (JSON or YAML), training configuration (JSON), weights, and "training mode"
-    * boolean indicator. When built in training mode, certain unsupported configurations
-    * (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
-    * will generate warnings but will be otherwise ignored.
+    * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
+    * training configuration (JSON), weights, and "training mode" boolean indicator. When built in
+    * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
+    * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
+    * otherwise ignored.
     *
     * @param modelJson model configuration JSON string
     * @param modelYaml model configuration YAML string
     * @param trainingJson training configuration JSON string
     * @throws IOException I/O exception
     */
-   public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
-       String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
+   public KerasSequentialModel(
+       String modelJson,
+       String modelYaml,
+       Hdf5Archive weightsArchive,
+       String weightsRoot,
+       String trainingJson,
+       Hdf5Archive trainingArchive,
+       boolean enforceTrainingConfig,
        int[] inputShape)
-       throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
+       throws IOException,
+           InvalidKerasConfigurationException,
+           UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
@@ -83,19 +99,29 @@ public class KerasSequentialModel extends KerasModel {
        /* Determine model configuration type. */
        if (!modelConfig.containsKey(config.getFieldClassName()))
            throw new InvalidKerasConfigurationException(
-               "Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
+               "Could not determine Keras model class (no "
+                   + config.getFieldClassName()
+                   + " field found)");
        this.className = (String) modelConfig.get(config.getFieldClassName());
        if (!this.className.equals(config.getFieldClassNameSequential()))
-           throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
-               + " (found " + this.className + ")");
+           throw new InvalidKerasConfigurationException(
+               "Model class name must be "
+                   + config.getFieldClassNameSequential()
+                   + " (found "
+                   + this.className
+                   + ")");
        /* Process layer configurations. */
        if (!modelConfig.containsKey(config.getModelFieldConfig()))
            throw new InvalidKerasConfigurationException(
-               "Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
+               "Could not find layer configurations (no "
+                   + config.getModelFieldConfig()
+                   + " field found)");
-       // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
-       // "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
+       // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
+       // For consistency
+       // "config" is now an object containing a "name" and "layers", the latter contain the same data
+       // as before.
        // This change only affects Sequential models.
        List<Object> layerList;
        try {
@@ -105,8 +131,7 @@ public class KerasSequentialModel extends KerasModel {
            layerList = (List<Object>) layerMap.get("layers");
        }
-       Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
-           prepareLayers(layerList);
+       Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
        this.layers = layerPair.getFirst();
        this.layersOrdered = layerPair.getSecond();
@@ -116,15 +141,18 @@ public class KerasSequentialModel extends KerasModel {
        } else {
            /* Add placeholder input layer and update lists of input and output layers. */
            int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
-           Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
+           Preconditions.checkState(
+               ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
            inputLayer = new KerasInput("input1", firstLayerInputShape);
            inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
            this.layers.put(inputLayer.getName(), inputLayer);
            this.layersOrdered.add(0, inputLayer);
        }
        this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
-       this.outputLayerNames = new ArrayList<>(
-           Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
+       this.outputLayerNames =
+           new ArrayList<>(
+               Collections.singletonList(
+                   this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
        /* Update each layer's inbound layer list to include (only) previous layer. */
        KerasLayer prevLayer = null;
@@ -136,12 +164,13 @@ public class KerasSequentialModel extends KerasModel {
        /* Import training configuration. */
        if (enforceTrainingConfig) {
-           if (trainingJson != null)
-               importTrainingConfiguration(trainingJson);
-           else log.warn("If enforceTrainingConfig is true, a training " +
-               "configuration object has to be provided. Usually the only practical way to do this is to store" +
-               " your keras model with `model.save('model_path.h5'. If you store model config and weights" +
-               " separately no training configuration is attached.");
+           if (trainingJson != null) importTrainingConfiguration(trainingJson);
+           else
+               log.warn(
+                   "If enforceTrainingConfig is true, a training "
+                       + "configuration object has to be provided. Usually the only practical way to do this is to store"
+                       + " your keras model with `model.save('model_path.h5'. If you store model config and weights"
+                       + " separately no training configuration is attached.");
        }
        this.outputTypes = inferOutputTypes(inputShape);
@@ -150,9 +179,7 @@ public class KerasSequentialModel extends KerasModel {
        importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
    }
-   /**
-    * Default constructor
-    */
+   /** Default constructor */
    public KerasSequentialModel() {
        super();
    }
@@ -174,14 +201,14 @@ public class KerasSequentialModel extends KerasModel {
        throw new InvalidKerasConfigurationException(
            "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
-   NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
+   NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
+       NeuralNetConfiguration.builder();
    if (optimizer != null) {
        modelBuilder.updater(optimizer);
    }
-   //don't forcibly override for keras import
+   // don't forcibly override for keras import
    modelBuilder.overrideNinUponBuild(false);
    /* Add layers one at a time. */
    KerasLayer prevLayer = null;
@@ -192,7 +219,10 @@ public class KerasSequentialModel extends KerasModel {
    if (nbInbound != 1)
        throw new InvalidKerasConfigurationException(
            "Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
-               + nbInbound + " for layer " + layer.getName() + ")");
+               + nbInbound
+               + " for layer "
+               + layer.getName()
+               + ")");
    if (prevLayer != null) {
        InputType[] inputTypes = new InputType[1];
        InputPreProcessor preprocessor;
@@ -200,42 +230,44 @@ public class KerasSequentialModel extends KerasModel {
                inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
                preprocessor = prevLayer.getInputPreprocessor(inputTypes);
                InputType outputType = preprocessor.getOutputType(inputTypes[0]);
-               layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
+               layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
            } else {
                inputTypes[0] = this.outputTypes.get(prevLayer.getName());
                preprocessor = layer.getInputPreprocessor(inputTypes);
-               if(preprocessor != null) {
+               if (preprocessor != null) {
                    InputType outputType = preprocessor.getOutputType(inputTypes[0]);
-                   layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
+                   layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
+               } else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
                }
-               else if (preprocessor != null) {
-                   layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
-                   Map<Integer, InputPreProcessor> map = new HashMap<>();
-                   map.put(layerIndex, preprocessor);
-                   modelBuilder.inputPreProcessors(map);
-               }
            }
+           if (preprocessor != null)
+               modelBuilder.inputPreProcessor(layerIndex, preprocessor);
        }
        modelBuilder.layer(layerIndex++, layer.getLayer());
    } else if (layer.getVertex() != null)
-       throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
-           + layer.getClassName() + ", layer name " + layer.getName() + ")");
+       throw new InvalidKerasConfigurationException(
+           "Cannot add vertex to NeuralNetConfiguration (class name "
+               + layer.getClassName()
+               + ", layer name "
+               + layer.getName()
+               + ")");
    prevLayer = layer;
}
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
-   modelBuilder.backpropType(BackpropType.TruncatedBPTT)
+   modelBuilder
+       .backpropType(BackpropType.TruncatedBPTT)
        .tbpttFwdLength(truncatedBPTT)
        .tbpttBackLength(truncatedBPTT);
-else
-   modelBuilder.backpropType(BackpropType.Standard);
+else modelBuilder.backpropType(BackpropType.Standard);
NeuralNetConfiguration build = modelBuilder.build();
return build;
}
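
Functionally, the notable change in this file is the preprocessor wiring: instead of collecting a Map<Integer, InputPreProcessor> and calling inputPreProcessors(map), each preprocessor is now registered directly via inputPreProcessor(index, preprocessor). A minimal sketch of the new call; the concrete preprocessor (CnnToFeedForwardPreProcessor with its height/width/channels constructor) is a standard DL4J class chosen for illustration, not part of this diff:

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;

public class PreprocessorWiringSketch {
    public static void main(String[] args) {
        int layerIndex = 1; // hypothetical index of the layer the preprocessor feeds
        InputPreProcessor preprocessor = new CnnToFeedForwardPreProcessor(2, 2, 3);
        // New style: register per index on the builder, no intermediate map.
        NeuralNetConfiguration.builder()
                .inputPreProcessor(layerIndex, preprocessor);
    }
}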

View File

@@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -84,29 +84,29 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
    IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
        enforceTrainingConfig, conf, kerasMajorVersion);
-   ConvolutionLayer.ConvolutionLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
+   var builder = Convolution1D.builder().name(this.name)
        .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
        .activation(getIActivationFromConfig(layerConfig, conf))
        .weightInit(init)
        .dilation(getDilationRate(layerConfig, 1, conf, true)[0])
        .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
        .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
-       .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
+       .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion))
        .hasBias(hasBias)
        .rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
-       .stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
+       .stride(getStrideFromConfig(layerConfig, 1, conf));
    int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
    if (hasBias)
        builder.biasInit(0.0);
    if (padding != null)
-       builder.padding(padding[0]);
+       builder.padding(padding);
    if (biasConstraint != null)
        builder.constrainBias(biasConstraint);
    if (weightConstraint != null)
        builder.constrainWeights(weightConstraint);
    this.layer = builder.build();
-   Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer;
-   convolution1DLayer.setDefaultValueOverriden(true);
+   Convolution1D convolution1D = (Convolution1D) layer;
+   convolution1D.setDefaultValueOverriden(true);
}
/**
@@ -114,8 +114,8 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
 *
 * @return ConvolutionLayer
 */
-public Convolution1DLayer getAtrousConvolution1D() {
-    return (Convolution1DLayer) this.layer;
+public Convolution1D getAtrousConvolution1D() {
+    return (Convolution1D) this.layer;
}
/**

View File

@@ -24,6 +24,7 @@ import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -85,7 +86,7 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
    IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
        enforceTrainingConfig, conf, kerasMajorVersion);
-   val builder = ConvolutionLayer.builder().name(this.name)
+   val builder = Convolution2D.builder().name(this.name)
        .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
        .activation(getIActivationFromConfig(layerConfig, conf))
        .weightInit(init)

View File

@@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -93,7 +93,7 @@ public class KerasConvolution1D extends KerasConvolution {
    IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
        enforceTrainingConfig, conf, kerasMajorVersion);
-   Convolution1DLayer.Convolution1DLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
+   var builder = Convolution1D.builder().name(this.name)
        .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
        .activation(getIActivationFromConfig(layerConfig, conf))
        .weightInit(init)
@@ -125,9 +125,9 @@ public class KerasConvolution1D extends KerasConvolution {
    this.layer = builder.build();
    //set this in order to infer the dimensional format
-   Convolution1DLayer convolution1DLayer = (Convolution1DLayer) this.layer;
-   convolution1DLayer.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
-   convolution1DLayer.setDefaultValueOverriden(true);
+   Convolution1D convolution1D = (Convolution1D) this.layer;
+   convolution1D.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
+   convolution1D.setDefaultValueOverriden(true);
}
/**
@@ -135,8 +135,8 @@ public class KerasConvolution1D extends KerasConvolution {
 *
 * @return ConvolutionLayer
 */
-public Convolution1DLayer getConvolution1DLayer() {
-    return (Convolution1DLayer) this.layer;
+public Convolution1D getConvolution1DLayer() {
+    return (Convolution1D) this.layer;
}

View File

@@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -95,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
    LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
        layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
-   final var builder = ConvolutionLayer.builder().name(this.name)
+   final var builder = Convolution2D.builder().name(this.name)
        .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
        .activation(getIActivationFromConfig(layerConfig, conf))
        .weightInit(init)

View File

@@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
@@ -41,8 +42,8 @@ public class JsonTest extends BaseDL4JTest {
    };
    for(InputPreProcessor p : pp ){
-       String s = NeuralNetConfiguration.mapper().writeValueAsString(p);
-       InputPreProcessor p2 = NeuralNetConfiguration.mapper().readValue(s, InputPreProcessor.class);
+       String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(p);
+       InputPreProcessor p2 = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, InputPreProcessor.class);
        assertEquals(p, p2);
    }

View File

@@ -29,11 +29,8 @@ import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
-import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
-import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.layers.LossLayer;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
@@ -656,7 +653,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
        true, true, false, null, null);
    Layer l = net.getLayer(0);
-   Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig();
+   Convolution1D c1d = (Convolution1D) l.getTrainingConfig();
    assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
}
}

View File

@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.dropout.Dropout;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@@ -97,7 +97,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
    config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
    layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
-   Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
+   Convolution1D layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
    assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
    assertEquals(LAYER_NAME, layer.getName());
    assertEquals(INIT_DL4J, layer.getWeightInit());

View File

@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.dropout.Dropout;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@@ -119,7 +119,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
    config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
    layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
-   Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
+   Convolution1D layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
    assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
    assertEquals(LAYER_NAME, layer.getName());
    assertEquals(INIT_DL4J, layer.getWeightInit());

View File

@@ -22,8 +22,6 @@
package net.brutex.ai.dnn.api;
import java.io.Serializable;
-import java.util.List;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {

View File

@@ -31,9 +31,11 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
public class NN {
-   public static NeuralNetConfigurationBuilder<?, ?> net() {
+   public static NeuralNetConfigurationBuilder<?, ?> nn() {
        return NeuralNetConfiguration.builder();
    }
+   public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
}
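
The facade's factory is renamed from net() to nn() and gains a dense() shortcut. A minimal usage sketch, assuming the returned builders behave like the underlying NeuralNetConfiguration and DenseLayer builders shown elsewhere in this diff:

import net.brutex.ai.dnn.api.NN;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

public class NNFacadeSketch {
    public static void main(String[] args) {
        // nn() wraps NeuralNetConfiguration.builder(); dense() wraps DenseLayer.builder().
        NeuralNetConfiguration conf = NN.nn()
                .layer(NN.dense().nIn(4).nOut(2).build())
                .build();
        System.out.println(conf);
    }
}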

View File

@@ -23,7 +23,6 @@ package net.brutex.ai.dnn.networks;
import java.io.Serializable;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.Map;
import lombok.Getter;
import lombok.NonNull;
@@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
 * Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
 * transforms this input data by calculating a weighted sum over the inputs and (3) applies a

View File

@@ -20,6 +20,10 @@
package org.deeplearning4j.earlystopping;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import lombok.Data;
import lombok.NoArgsConstructor;
import net.brutex.ai.dnn.api.IModel;
@@ -30,11 +34,6 @@ import org.deeplearning4j.earlystopping.termination.IterationTerminationConditio
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.nd4j.common.function.Supplier;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
@Data
@NoArgsConstructor
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {

View File

@@ -20,16 +20,15 @@
package org.deeplearning4j.earlystopping;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonSubTypes;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import java.io.IOException;
+import java.io.Serializable;
import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonSubTypes;
-import com.fasterxml.jackson.annotation.JsonTypeInfo;
-import java.io.IOException;
-import java.io.Serializable;
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),

View File

@@ -20,11 +20,10 @@
package org.deeplearning4j.earlystopping;
-import lombok.Data;
-import net.brutex.ai.dnn.api.IModel;
import java.io.Serializable;
import java.util.Map;
+import lombok.Data;
+import net.brutex.ai.dnn.api.IModel;
@Data
public class EarlyStoppingResult<T extends IModel> implements Serializable {

View File

@@ -20,10 +20,9 @@
package org.deeplearning4j.earlystopping.saver;
-import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
-import net.brutex.ai.dnn.api.IModel;
import java.io.IOException;
+import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {

View File

@@ -20,15 +20,14 @@
package org.deeplearning4j.earlystopping.saver;
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
-import java.io.File;
-import java.io.IOException;
-import java.nio.charset.Charset;
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
    private static final String BEST_GRAPH_BIN = "bestGraph.bin";

View File

@@ -20,15 +20,14 @@
package org.deeplearning4j.earlystopping.saver;
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
-import java.io.File;
-import java.io.IOException;
-import java.nio.charset.Charset;
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
    private static final String BEST_MODEL_BIN = "bestModel.bin";

View File

@@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {

View File

@@ -20,8 +20,9 @@
package org.deeplearning4j.earlystopping.scorecalc;
-import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
+import com.fasterxml.jackson.annotation.JsonProperty;
import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -29,7 +30,6 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
-import com.fasterxml.jackson.annotation.JsonProperty;
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {

View File

@@ -20,6 +20,8 @@
package org.deeplearning4j.earlystopping.scorecalc;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.NoArgsConstructor;
import lombok.val;
import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -27,8 +29,6 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
-import com.fasterxml.jackson.annotation.JsonIgnore;
-import com.fasterxml.jackson.annotation.JsonProperty;
@NoArgsConstructor
@Deprecated

View File

@@ -20,12 +20,11 @@
package org.deeplearning4j.earlystopping.scorecalc;
-import net.brutex.ai.dnn.api.IModel;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
+import net.brutex.ai.dnn.api.IModel;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonInclude(JsonInclude.Include.NON_NULL)

View File

@@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {

View File

@@ -20,9 +20,9 @@
package org.deeplearning4j.earlystopping.scorecalc.base;
+import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
-import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;

View File

@@ -21,8 +21,8 @@
package org.deeplearning4j.earlystopping.scorecalc.base;
import lombok.NonNull;
-import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

View File

@@ -20,8 +20,8 @@
package org.deeplearning4j.earlystopping.termination;
-import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
@Data
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {

View File

@@ -22,9 +22,7 @@ package org.deeplearning4j.earlystopping.termination;
import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")

View File

@@ -22,7 +22,6 @@ package org.deeplearning4j.earlystopping.termination;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")

View File

@@ -20,10 +20,10 @@
package org.deeplearning4j.earlystopping.termination;
-import lombok.Data;
-import lombok.NoArgsConstructor;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+import lombok.NoArgsConstructor;
@NoArgsConstructor
@Data

View File

@@ -20,8 +20,8 @@
package org.deeplearning4j.earlystopping.termination;
-import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
@Data
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {

View File

@@ -20,10 +20,9 @@
 package org.deeplearning4j.earlystopping.termination;
-import lombok.Data;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import java.util.concurrent.TimeUnit;
+import lombok.Data;
 /**Terminate training based on max time.
  */

View File

@@ -20,9 +20,9 @@
 package org.deeplearning4j.earlystopping.termination;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
-import com.fasterxml.jackson.annotation.JsonProperty;
 @Slf4j
 @Data

View File

@@ -20,6 +20,12 @@
 package org.deeplearning4j.earlystopping.trainer;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
 import org.deeplearning4j.earlystopping.EarlyStoppingResult;
@@ -40,13 +46,6 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.Map;
 public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
     private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);

View File

@@ -20,7 +20,6 @@
 package org.deeplearning4j.earlystopping.trainer;
-import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
 import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
 import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;

View File

@@ -20,6 +20,13 @@
 package org.deeplearning4j.eval;
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
 import org.nd4j.common.primitives.AtomicBoolean;
@@ -28,14 +35,6 @@ import org.nd4j.common.primitives.serde.JsonDeserializerAtomicBoolean;
 import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
 import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
 import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
-import com.fasterxml.jackson.annotation.JsonAutoDetect;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationFeature;
-import com.fasterxml.jackson.databind.MapperFeature;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.SerializationFeature;
-import com.fasterxml.jackson.databind.module.SimpleModule;
-import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 @Deprecated
 @EqualsAndHashCode(callSuper = false)
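Note: the Jackson imports kept at the top of this file back the custom (de)serializers for the nd4j atomic primitives. The registration itself is outside the captured hunks; a minimal sketch of what such wiring typically looks like, assuming the serde classes behave as ordinary Jackson serializers (all types are the ones imported in the hunks above):

    // Sketch only: module wiring implied by the imports above, not this file's exact code.
    static ObjectMapper configuredMapper() {
        SimpleModule module = new SimpleModule();
        module.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
        module.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
        module.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
        module.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
        ObjectMapper mapper = new ObjectMapper(); // or new ObjectMapper(new YAMLFactory()) for YAML
        mapper.registerModule(module);
        return mapper;
    }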

View File

@@ -20,15 +20,8 @@
 package org.deeplearning4j.eval;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import lombok.Getter;
-import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 @Deprecated
 public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {

View File

@@ -20,14 +20,11 @@
 package org.deeplearning4j.eval;
-import lombok.EqualsAndHashCode;
-import lombok.NonNull;
-import org.nd4j.evaluation.EvaluationAveraging;
-import org.nd4j.evaluation.IEvaluation;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import java.util.List;
 import java.util.Map;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.linalg.api.ndarray.INDArray;
 @EqualsAndHashCode(callSuper = true)
 @Deprecated

View File

@@ -20,9 +20,9 @@
 package org.deeplearning4j.eval;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import com.fasterxml.jackson.annotation.JsonProperty;
 @Deprecated
 @Getter

View File

@@ -20,11 +20,10 @@
 package org.deeplearning4j.eval;
+import java.util.List;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import java.util.List;
 @Deprecated
 @Data
 @EqualsAndHashCode(callSuper = true)

View File

@@ -20,10 +20,10 @@
 package org.deeplearning4j.eval.curves;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import org.nd4j.evaluation.curves.BaseHistogram;
-import com.fasterxml.jackson.annotation.JsonProperty;
 @Deprecated
 @Data

View File

@@ -20,13 +20,9 @@
 package org.deeplearning4j.eval.curves;
-import com.google.common.base.Preconditions;
-import lombok.AllArgsConstructor;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import com.fasterxml.jackson.annotation.JsonProperty;
-import java.util.Arrays;
 @Deprecated
 @Data

View File

@@ -20,8 +20,8 @@
 package org.deeplearning4j.eval.curves;
-import lombok.NonNull;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.NonNull;
 @Deprecated
 public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {

View File

@@ -20,10 +20,9 @@
 package org.deeplearning4j.eval.curves;
-import com.google.common.base.Preconditions;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import com.fasterxml.jackson.annotation.JsonProperty;
 @Deprecated
 @Data

View File

@@ -20,7 +20,6 @@
 package org.deeplearning4j.eval.meta;
-import lombok.AllArgsConstructor;
 import lombok.Data;
 @Data

View File

@@ -20,6 +20,7 @@
 package org.deeplearning4j.nn.adapters;
+import java.util.List;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.NoArgsConstructor;
@@ -32,8 +33,6 @@ import org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import java.util.List;
 @Builder
 @AllArgsConstructor
 @NoArgsConstructor

View File

@@ -21,7 +21,6 @@
 package org.deeplearning4j.nn.api;
-import lombok.Getter;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;

View File

@@ -20,14 +20,12 @@
 package org.deeplearning4j.nn.api;
+import java.util.List;
 import net.brutex.ai.dnn.api.IModel;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import java.util.List;
 public interface Classifier extends IModel {

View File

@@ -20,13 +20,12 @@
 package org.deeplearning4j.nn.api;
+import java.util.List;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
-import java.util.List;
 public interface ITraininableLayerConfiguration {
     /**

View File

@@ -21,7 +21,7 @@
 package org.deeplearning4j.nn.api;
-import java.util.Map;
+import java.io.Serializable;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -29,10 +29,8 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.LayerHelper;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.common.primitives.Pair;
-import java.io.Serializable;
+import org.nd4j.linalg.api.ndarray.INDArray;
 /**
  * A layer is the highest-level building block in deep learning. A layer is a container that usually

View File

@@ -20,13 +20,12 @@
 package org.deeplearning4j.nn.api;
+import java.util.List;
+import java.util.Map;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import java.util.List;
-import java.util.Map;
 /**
  * Param initializer for a layer
  *

View File

@@ -20,11 +20,10 @@
 package org.deeplearning4j.nn.api;
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import java.io.Serializable;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
 /**
  * Update the model

View File

@@ -22,8 +22,8 @@ package org.deeplearning4j.nn.api.layers;
 import org.deeplearning4j.nn.api.Classifier;
 import org.deeplearning4j.nn.api.Layer;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
 public interface IOutputLayer extends Layer, Classifier {

View File

@@ -20,11 +20,10 @@
 package org.deeplearning4j.nn.api.layers;
-import org.deeplearning4j.nn.api.Layer;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import java.io.Serializable;
 import java.util.Set;
+import org.deeplearning4j.nn.api.Layer;
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface LayerConstraint extends Cloneable, Serializable {

View File

@@ -20,13 +20,12 @@
 package org.deeplearning4j.nn.api.layers;
+import java.util.Map;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.gradient.Gradient;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.primitives.Pair;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import java.util.Map;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
 public interface RecurrentLayer extends Layer {

View File

@@ -20,6 +20,12 @@
 package org.deeplearning4j.nn.conf;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
 import lombok.*;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
@@ -34,6 +40,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
@@ -42,16 +49,9 @@ import org.nd4j.common.base.Preconditions;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.buffer.DataType;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.*;
 @Data
 @EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
 @AllArgsConstructor(access = AccessLevel.PRIVATE)
@@ -110,7 +110,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      * @return YAML representation of configuration
      */
     public String toYaml() {
-        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
         synchronized (mapper) {
             try {
                 return mapper.writeValueAsString(this);
@@ -127,7 +127,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      * @return {@link ComputationGraphConfiguration}
      */
     public static ComputationGraphConfiguration fromYaml(String json) {
-        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
         try {
             return mapper.readValue(json, ComputationGraphConfiguration.class);
         } catch (IOException e) {
@@ -140,7 +140,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      */
     public String toJson() {
         //As per NeuralNetConfiguration.toJson()
-        ObjectMapper mapper = NeuralNetConfiguration.mapper();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
         synchronized (mapper) {
             //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
             //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
@@ -160,7 +160,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      */
     public static ComputationGraphConfiguration fromJson(String json) {
         //As per NeuralNetConfiguration.fromJson()
-        ObjectMapper mapper = NeuralNetConfiguration.mapper();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
         ComputationGraphConfiguration conf;
         try {
             conf = mapper.readValue(json, ComputationGraphConfiguration.class);
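Note: the substantive change in this file is the swap from NeuralNetConfiguration.mapper()/mapperYaml() to a central CavisMapper factory. Only the call sites appear in the diff, so the class body below is an assumption: a minimal sketch of a factory exposing the getMapper(Type) signature the hunks use, with the package taken from the import added above.

    package org.deeplearning4j.nn.conf.serde;

    import com.fasterxml.jackson.databind.DeserializationFeature;
    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;

    // Hypothetical body: only the getMapper(Type) signature comes from the diff.
    public class CavisMapper {

        public enum Type { JSON, YAML }

        public static ObjectMapper getMapper(Type type) {
            // YAML shares Jackson's object model but needs its own factory.
            ObjectMapper mapper = (type == Type.YAML)
                    ? new ObjectMapper(new YAMLFactory())
                    : new ObjectMapper();
            // Tolerating unknown fields keeps older serialized configurations loadable.
            mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
            return mapper;
        }
    }

As exercised by the toJson()/fromJson() methods above, a round trip would be:

    String json = conf.toJson();
    ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(json);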

View File

@@ -19,10 +19,10 @@
  */
 package org.deeplearning4j.nn.conf;
-import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
-import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
+import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
 @JsonSerialize(using = DataFormatSerializer.class)
 @JsonDeserialize(using = DataFormatDeserializer.class)
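DataFormatSerializer and DataFormatDeserializer themselves are not part of this diff; for context, a custom Jackson serializer referenced this way has roughly the following shape (only the class and package names come from the imports above, the body is an assumption):

    package org.deeplearning4j.nn.conf.serde.format;

    import java.io.IOException;
    import com.fasterxml.jackson.core.JsonGenerator;
    import com.fasterxml.jackson.databind.JsonSerializer;
    import com.fasterxml.jackson.databind.SerializerProvider;
    import org.deeplearning4j.nn.conf.DataFormat;

    // Sketch: writes the DataFormat as its string form, e.g. "NCHW".
    public class DataFormatSerializer extends JsonSerializer<DataFormat> {
        @Override
        public void serialize(DataFormat value, JsonGenerator gen, SerializerProvider provider)
                throws IOException {
            gen.writeString(value.toString());
        }
    }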

Some files were not shown because too many files have changed in this diff.