Compare commits
26 Commits
enhance-bu
...
master
Author | SHA1 | Date |
---|---|---|
Brian Rosenberger | 1c1ec071ef | |
Brian Rosenberger | 74ad5087c1 | |
Brian Rosenberger | acae3944ec | |
Brian Rosenberger | be7cd6b930 | |
Brian Rosenberger | 99aed71ffa | |
Brian Rosenberger | 2df8ea06e0 | |
Brian Rosenberger | 090c5ab2eb | |
Brian Rosenberger | a40d5aa7cf | |
Brian Rosenberger | d2972e4f24 | |
Brian Rosenberger | 704f4860d5 | |
Brian Rosenberger | d5728cbd8e | |
Brian Rosenberger | d40c044df4 | |
Brian Rosenberger | a6c4a16d9a | |
Brian Rosenberger | 0e4be5c4d2 | |
Brian Rosenberger | f7be1e324f | |
Brian Rosenberger | 1c3496ad84 | |
Brian Rosenberger | 3ea555b645 | |
Brian Rosenberger | e11568605d | |
Brian Rosenberger | 9f0682eb75 | |
Brian Rosenberger | ca127d8b88 | |
Brian Rosenberger | deb436036b | |
Brian Rosenberger | 1f2bfb36a5 | |
Brian Rosenberger | b477b71325 | |
Brian Rosenberger | d75e0be506 | |
Brian Rosenberger | 318cafb6f0 | |
Brian Rosenberger | 24466a8fd4 |
|
@ -1,4 +1,4 @@
|
||||||
FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
|
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
||||||
|
@ -11,5 +11,10 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
|
||||||
rm cmake-3.24.2-linux-x86_64.sh
|
rm cmake-3.24.2-linux-x86_64.sh
|
||||||
|
|
||||||
|
|
||||||
|
RUN echo "/usr/local/cuda/compat/" >> /etc/ld.so.conf.d/cuda-driver.conf
|
||||||
|
|
||||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
||||||
|
|
||||||
|
RUN ldconfig -p | grep cuda
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,8 @@ pom.xml.versionsBackup
|
||||||
pom.xml.next
|
pom.xml.next
|
||||||
release.properties
|
release.properties
|
||||||
*dependency-reduced-pom.xml
|
*dependency-reduced-pom.xml
|
||||||
|
**/build/*
|
||||||
|
.gradle/*
|
||||||
|
|
||||||
# Specific for Nd4j
|
# Specific for Nd4j
|
||||||
*.md5
|
*.md5
|
||||||
|
@ -83,3 +85,14 @@ bruai4j-native-common/cmake*
|
||||||
/bruai4j-native/bruai4j-native-common/blasbuild/
|
/bruai4j-native/bruai4j-native-common/blasbuild/
|
||||||
/bruai4j-native/bruai4j-native-common/build/
|
/bruai4j-native/bruai4j-native-common/build/
|
||||||
/cavis-native/cavis-native-lib/blasbuild/
|
/cavis-native/cavis-native-lib/blasbuild/
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
|
||||||
|
/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
|
||||||
|
|
|
@ -35,7 +35,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ pipeline {
|
||||||
}*/
|
}*/
|
||||||
stage('publish-linux-cpu') {
|
stage('publish-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,4 +79,9 @@ pipeline {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
post {
|
||||||
|
always {
|
||||||
|
junit '**/build/test-results/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,13 +21,15 @@
|
||||||
|
|
||||||
pipeline {
|
pipeline {
|
||||||
agent {
|
agent {
|
||||||
dockerfile {
|
/* dockerfile {
|
||||||
filename 'Dockerfile'
|
filename 'Dockerfile'
|
||||||
dir '.docker'
|
dir '.docker'
|
||||||
label 'linux && cuda'
|
label 'linux && cuda'
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
//args '--gpus all' --needed for test only, you can build without GPU
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
label 'linux && cuda'
|
||||||
}
|
}
|
||||||
|
|
||||||
stages {
|
stages {
|
||||||
|
@ -43,13 +45,13 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
steps {
|
steps {
|
||||||
withGradle {
|
withGradle {
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||||
}
|
}
|
||||||
|
@ -57,4 +59,10 @@ pipeline {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
post {
|
||||||
|
always {
|
||||||
|
junit '**/build/test-results/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,4 +85,9 @@ pipeline {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
post {
|
||||||
|
always {
|
||||||
|
junit '**/build/test-results/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ pipeline {
|
||||||
stages {
|
stages {
|
||||||
stage('publish-linux-cpu') {
|
stage('publish-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ pipeline {
|
||||||
dir '.docker'
|
dir '.docker'
|
||||||
label 'linux && docker && cuda'
|
label 'linux && docker && cuda'
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
args '--gpus all' //needed for test only, you can build without GPU
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,5 +56,26 @@ pipeline {
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stage('test-linux-cuda') {
|
||||||
|
environment {
|
||||||
|
MAVEN = credentials('Internal_Archiva')
|
||||||
|
OSSRH = credentials('OSSRH')
|
||||||
|
}
|
||||||
|
|
||||||
|
steps {
|
||||||
|
withGradle {
|
||||||
|
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
|
||||||
|
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||||
|
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||||
|
}
|
||||||
|
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
post {
|
||||||
|
always {
|
||||||
|
junit '**/build/test-results/**/*.xml'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cpu') {
|
stage('build-linux-cpu') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal Archiva')
|
MAVEN = credentials('Internal_Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,167 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.ai.nd4j.tests;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class ExploreParamsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testParam() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(4).nOut(30).name("1. Dense").activation(Activation.TANH))
|
||||||
|
.layer(DenseLayer.builder().nIn(30).nOut(10).name("2. Dense"))
|
||||||
|
// .layer(FrozenLayer.builder(DenseLayer.builder().nOut(6).build()).build())
|
||||||
|
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(3)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
|
||||||
|
nn.init();
|
||||||
|
log.info(nn.summary());
|
||||||
|
// INDArray input = Nd4j.rand(10,4);
|
||||||
|
INDArray labels = Nd4j.zeros(9, 3);
|
||||||
|
|
||||||
|
INDArray input =
|
||||||
|
Nd4j.create(
|
||||||
|
new double[][] {
|
||||||
|
{5.15, 3.5, 1.4, 0.21}, // setosa
|
||||||
|
{4.9, 3.2, 1.4, 0.2}, // setosa
|
||||||
|
{4.7, 3.2, 1.23, 0.2}, // setosa
|
||||||
|
{7, 3.25, 4.7, 1.41}, // versicolor
|
||||||
|
{6.4, 3.2, 4.54, 1.5}, // versicolor
|
||||||
|
{6.9, 3.1, 4.92, 1.5}, // versicolor
|
||||||
|
{7.7, 3, 6.1, 2.3}, // virginica
|
||||||
|
{6.3, 3.4, 5.6, 2.45}, // virginica
|
||||||
|
{6.4, 3.12, 5.5, 1.8} // virginica
|
||||||
|
});
|
||||||
|
|
||||||
|
labels.putScalar(0, 1);
|
||||||
|
labels.putScalar(3, 1);
|
||||||
|
labels.putScalar(6, 1);
|
||||||
|
labels.putScalar(10, 1);
|
||||||
|
labels.putScalar(13, 1);
|
||||||
|
labels.putScalar(16, 1);
|
||||||
|
labels.putScalar(20, 1);
|
||||||
|
labels.putScalar(23, 1);
|
||||||
|
labels.putScalar(26, 1);
|
||||||
|
|
||||||
|
IrisDataSetIterator iter = new IrisDataSetIterator();
|
||||||
|
//Iterable<Pair<INDArray, INDArray>> it = List.of(new Pair<INDArray, INDArray>(input, labels));
|
||||||
|
List l = new ArrayList<>();
|
||||||
|
for (int i=0; i< input.rows(); i++) {
|
||||||
|
l.add(new Pair(input.getRow(i), labels.getRow(i)));
|
||||||
|
}
|
||||||
|
Iterable<Pair<INDArray, INDArray>> it = l;
|
||||||
|
INDArrayDataSetIterator diter = new INDArrayDataSetIterator(it, 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < 100; i++) {
|
||||||
|
// nn.fit(input, labels);
|
||||||
|
// nn.fit( input, labels);
|
||||||
|
nn.fit(diter);
|
||||||
|
// nn.feedForward(input);
|
||||||
|
if(i%20==0) log.info("Score: {}", nn.getScore());
|
||||||
|
}
|
||||||
|
|
||||||
|
Evaluation eval = nn.evaluate(iter, List.of("setosa", "vericolor", "virginica"));
|
||||||
|
|
||||||
|
log.info("\n{}", eval.stats());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testParam2() throws IOException {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense"))
|
||||||
|
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(10)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn = new MultiLayerNetwork(conf);
|
||||||
|
nn.init();
|
||||||
|
log.info(nn.summary());
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf2 =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
DenseLayer.builder().nIn(784).nOut(20).name("1. Dense").dropOut(0.7))
|
||||||
|
.layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.nOut(10)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.SOFTMAX))
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork nn2 = new MultiLayerNetwork(conf2);
|
||||||
|
nn2.init();
|
||||||
|
log.info(nn2.summary());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
MnistDataSetIterator iter = new MnistDataSetIterator(10, 500);
|
||||||
|
MnistDataSetIterator iter2 = new MnistDataSetIterator(10, 50);
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = 0; i < 200; i++) {
|
||||||
|
nn.fit(iter);
|
||||||
|
nn2.fit(iter);
|
||||||
|
if(i%20==0) log.info("Score: {} vs. {}", nn.getScore(), nn2.getScore());
|
||||||
|
}
|
||||||
|
|
||||||
|
Evaluation eval = nn.evaluate(iter2);
|
||||||
|
Evaluation eval2 = nn2.evaluate(iter2);
|
||||||
|
|
||||||
|
log.info("\n{} \n{}", eval.stats(), eval2.stats());
|
||||||
|
}
|
||||||
|
}
|
|
@ -36,9 +36,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
public class LoadBackendTests {
|
public class LoadBackendTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadBackend() throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
|
public void loadBackend() throws NoSuchFieldException, IllegalAccessException {
|
||||||
// check if Nd4j is there
|
// check if Nd4j is there
|
||||||
//Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
|
Logger.getLogger(LoadBackendTests.class.getName()).info("System java.library.path: " + System.getProperty("java.library.path"));
|
||||||
final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||||
sysPathsField.setAccessible(true);
|
sysPathsField.setAccessible(true);
|
||||||
sysPathsField.set(null, null);
|
sysPathsField.set(null, null);
|
||||||
|
|
|
@ -1,110 +1,49 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package net.brutex.gan;
|
package net.brutex.gan;
|
||||||
|
|
||||||
import java.awt.BorderLayout;
|
import static net.brutex.ai.dnn.api.NN.dense;
|
||||||
import java.awt.Dimension;
|
|
||||||
import java.awt.GridLayout;
|
import java.awt.*;
|
||||||
import java.awt.Image;
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import javax.swing.*;
|
||||||
import javax.swing.ImageIcon;
|
|
||||||
import javax.swing.JFrame;
|
|
||||||
import javax.swing.JLabel;
|
|
||||||
import javax.swing.JPanel;
|
|
||||||
import javax.swing.WindowConstants;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
|
||||||
import org.datavec.image.transform.ColorConversionTransform;
|
|
||||||
import org.datavec.image.transform.ImageTransform;
|
|
||||||
import org.datavec.image.transform.PipelineImageTransform;
|
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
|
||||||
import org.datavec.image.transform.ShowImageTransform;
|
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
|
||||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class App {
|
public class App {
|
||||||
private static final double LEARNING_RATE = 0.000002;
|
private static final double LEARNING_RATE = 0.002;
|
||||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||||
|
|
||||||
private static final int X_DIM = 20 ;
|
|
||||||
private static final int Y_DIM = 20;
|
|
||||||
private static final int CHANNELS = 1;
|
|
||||||
private static final int batchSize = 10;
|
|
||||||
private static final int INPUT = 128;
|
|
||||||
|
|
||||||
private static final int OUTPUT_PER_PANEL = 4;
|
|
||||||
|
|
||||||
private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
|
|
||||||
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
|
||||||
|
private static final int BATCHSIZE = 128;
|
||||||
private static JFrame frame;
|
private static JFrame frame;
|
||||||
private static JFrame frame2;
|
|
||||||
private static JPanel panel;
|
private static JPanel panel;
|
||||||
private static JPanel panel2;
|
|
||||||
|
|
||||||
private static LayerConfiguration[] genLayers() {
|
private static LayerConfiguration[] genLayers() {
|
||||||
return new LayerConfiguration[] {
|
return new LayerConfiguration[] {
|
||||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
dense().nIn(256).nOut(512).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,65 +58,51 @@ public class App {
|
||||||
.updater(UPDATER)
|
.updater(UPDATER)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
//.weightInit(WeightInit.XAVIER)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.layersFromArray(genLayers())
|
.layersFromArray(genLayers())
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
.name("generator")
|
||||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
|
||||||
.build();
|
.build();
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static LayerConfiguration[] disLayers() {
|
private static LayerConfiguration[] disLayers() {
|
||||||
return new LayerConfiguration[]{
|
return new LayerConfiguration[]{
|
||||||
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
|
dense().nIn(784).nOut(1024).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
dense().nIn(1024).nOut(512).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
dense().nIn(512).nOut(256).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
|
||||||
|
|
||||||
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NeuralNetConfiguration discriminator() {
|
private static NeuralNetConfiguration discriminator() {
|
||||||
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
NeuralNetConfiguration conf =
|
|
||||||
NeuralNetConfiguration.builder()
|
|
||||||
.seed(42)
|
.seed(42)
|
||||||
.updater(UPDATER)
|
.updater(UPDATER)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
|
||||||
.weightNoise(null)
|
|
||||||
// .weightInitFn(new WeightInitXavier())
|
|
||||||
// .activationFn(new ActivationIdentity())
|
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.layersFromArray(disLayers())
|
.layersFromArray(disLayers())
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
.name("discriminator")
|
||||||
.build();
|
.build();
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NeuralNetConfiguration gan() {
|
private static NeuralNetConfiguration gan() {
|
||||||
LayerConfiguration[] genLayers = genLayers();
|
LayerConfiguration[] genLayers = genLayers();
|
||||||
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
.map((layer) -> {
|
.map((layer) -> {
|
||||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
return FrozenLayerWithBackprop.builder(layer).build();
|
||||||
} else {
|
} else {
|
||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
|
@ -186,107 +111,57 @@ public class App {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.seed(42)
|
.seed(42)
|
||||||
.updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
|
.updater(UPDATER)
|
||||||
.gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold( 100 )
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
//.weightInitFn( new WeightInitXavier() ) //this is internal
|
.weightInit(WeightInit.XAVIER)
|
||||||
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
.activation(Activation.IDENTITY)
|
||||||
.weightInit( WeightInit.XAVIER)
|
.layersFromArray(layers)
|
||||||
//.activationFn( new ActivationIdentity()) //this is internal
|
.name("GAN")
|
||||||
.activation( Activation.IDENTITY )
|
|
||||||
.layersFromArray( layers )
|
|
||||||
.inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
|
||||||
.build();
|
.build();
|
||||||
((NeuralNetConfiguration) conf).init();
|
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test @Tag("long-running")
|
||||||
@Test
|
|
||||||
public void runTest() throws Exception {
|
public void runTest() throws Exception {
|
||||||
main();
|
App.main(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String... args) throws Exception {
|
public static void main(String... args) throws Exception {
|
||||||
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(500);
|
|
||||||
|
|
||||||
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
|
||||||
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
|
||||||
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
|
|
||||||
|
|
||||||
|
|
||||||
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
|
||||||
|
|
||||||
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
|
||||||
ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
ImageTransform tr = new PipelineImageTransform.Builder()
|
|
||||||
.addImageTransform(transform) //convert to GREY SCALE
|
|
||||||
.addImageTransform(transform3)
|
|
||||||
//.addImageTransform(transform2)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
|
|
||||||
imageRecordReader.initialize(fileSplit, tr);
|
|
||||||
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
|
|
||||||
|
|
||||||
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
MultiLayerNetwork gen = new MultiLayerNetwork(generator());
|
||||||
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
|
||||||
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
MultiLayerNetwork gan = new MultiLayerNetwork(gan());
|
||||||
gen.init(); log.debug("Generator network: {}", gen);
|
gen.init();
|
||||||
dis.init(); log.debug("Discriminator network: {}", dis);
|
dis.init();
|
||||||
gan.init(); log.debug("Complete GAN network: {}", gan);
|
gan.init();
|
||||||
|
|
||||||
|
|
||||||
copyParams(gen, dis, gan);
|
copyParams(gen, dis, gan);
|
||||||
|
|
||||||
gen.addTrainingListeners(new PerformanceListener(15, true));
|
gen.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
//dis.addTrainingListeners(new PerformanceListener(10, true));
|
dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
//gan.addTrainingListeners(new PerformanceListener(10, true));
|
gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
|
||||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
|
||||||
|
|
||||||
System.out.println(gan.toString());
|
trainData.reset();
|
||||||
gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
|
|
||||||
|
|
||||||
//gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
|
|
||||||
//trainData.reset();
|
|
||||||
|
|
||||||
int j = 0;
|
int j = 0;
|
||||||
for (int i = 0; i < 201; i++) { //epoch
|
for (int i = 0; i < 50; i++) {
|
||||||
while (trainData.hasNext()) {
|
while (trainData.hasNext()) {
|
||||||
j++;
|
j++;
|
||||||
|
|
||||||
DataSet next = trainData.next();
|
|
||||||
// generate data
|
// generate data
|
||||||
INDArray real = next.getFeatures();//.div(255f);
|
INDArray real = trainData.next().getFeatures().muli(2).subi(1);
|
||||||
|
int batchSize = (int) real.shape()[0];
|
||||||
|
|
||||||
//start next round if there are not enough images left to have a full batchsize dataset
|
INDArray fakeIn = Nd4j.rand(batchSize, 100);
|
||||||
if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
|
|
||||||
log.warn("Your total number of input images is not a multiple of {}, "
|
|
||||||
+ "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(i%20 == 0) {
|
|
||||||
// frame2 = visualize(new INDArray[]{real}, batchSize,
|
|
||||||
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
|
||||||
}
|
|
||||||
real.divi(255f);
|
|
||||||
|
|
||||||
// int batchSize = (int) real.shape()[0];
|
|
||||||
|
|
||||||
INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
|
|
||||||
//log.info("real has {} items.", real.length());
|
|
||||||
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
|
||||||
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
|
||||||
|
|
||||||
|
|
||||||
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
dis.fit(data);
|
dis.fit(data);
|
||||||
|
@ -295,32 +170,26 @@ public class App {
|
||||||
// Update the discriminator in the GAN network
|
// Update the discriminator in the GAN network
|
||||||
updateGan(gen, dis, gan);
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
//gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
|
gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
|
||||||
gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
|
|
||||||
|
|
||||||
|
|
||||||
if (j % 10 == 1) {
|
if (j % 10 == 1) {
|
||||||
System.out.println("Iteration " + j + " Visualizing...");
|
System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
|
||||||
INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
|
INDArray[] samples = new INDArray[9];
|
||||||
|
|
||||||
|
|
||||||
for (int k = 0; k < samples.length; k++) {
|
|
||||||
//INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
|
||||||
INDArray input = fakeSet2.get(k).getFeatures();
|
|
||||||
input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
|
|
||||||
|
|
||||||
|
for (int k = 0; k < 9; k++) {
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
//samples[k] = gen.output(input, false);
|
//samples[k] = gen.output(input, false);
|
||||||
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
|
|
||||||
//samples[k] =
|
|
||||||
samples[k].addi(1f).divi(2f).muli(255f);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
visualize(samples);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
trainData.reset();
|
trainData.reset();
|
||||||
|
// Copy the GANs generator to gen.
|
||||||
|
//updateGen(gen, gan);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the GANs generator to gen.
|
// Copy the GANs generator to gen.
|
||||||
|
@ -333,10 +202,8 @@ public class App {
|
||||||
int genLayerCount = gen.getLayers().length;
|
int genLayerCount = gen.getLayers().length;
|
||||||
for (int i = 0; i < gan.getLayers().length; i++) {
|
for (int i = 0; i < gan.getLayers().length; i++) {
|
||||||
if (i < genLayerCount) {
|
if (i < genLayerCount) {
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
} else {
|
} else {
|
||||||
if(gan.getLayer(i).getParams() != null)
|
|
||||||
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -355,57 +222,41 @@ public class App {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
|
private static void visualize(INDArray[] samples) {
|
||||||
if (isOrig) {
|
if (frame == null) {
|
||||||
frame.setTitle("Viz Original");
|
frame = new JFrame();
|
||||||
} else {
|
frame.setTitle("Viz");
|
||||||
frame.setTitle("Generated");
|
|
||||||
}
|
|
||||||
|
|
||||||
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
frame.setLayout(new BorderLayout());
|
frame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
JPanel panelx = new JPanel();
|
panel = new JPanel();
|
||||||
|
|
||||||
panelx.setLayout(new GridLayout(4, 4, 8, 8));
|
panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
|
||||||
for (INDArray sample : samples) {
|
frame.add(panel, BorderLayout.CENTER);
|
||||||
for(int i = 0; i<batchElements; i++) {
|
|
||||||
panelx.add(getImage(sample, i, isOrig));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
frame.add(panelx, BorderLayout.CENTER);
|
|
||||||
frame.setVisible(true);
|
frame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
panel.removeAll();
|
||||||
|
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
panel.add(getImage(sample));
|
||||||
|
}
|
||||||
|
|
||||||
frame.revalidate();
|
frame.revalidate();
|
||||||
frame.setMinimumSize(new Dimension(300, 20));
|
|
||||||
frame.pack();
|
frame.pack();
|
||||||
return frame;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
private static JLabel getImage(INDArray tensor) {
|
||||||
final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
|
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
|
||||||
final int imageSize = X_DIM * Y_DIM;
|
for (int i = 0; i < 784; i++) {
|
||||||
final int offset = batchElement * imageSize;
|
int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
|
||||||
int pxl = offset * CHANNELS; //where to start in the INDArray
|
bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
|
||||||
|
|
||||||
//Image in NCHW - channels first format
|
|
||||||
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
|
||||||
for (int y = 0; y < Y_DIM; y++) { // step through the columns x
|
|
||||||
for (int x = 0; x < X_DIM; x++) { //step through the rows y
|
|
||||||
if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
|
|
||||||
bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
|
|
||||||
pxl++; //next item in INDArray
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ImageIcon orig = new ImageIcon(bi);
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
|
||||||
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
|
||||||
|
|
||||||
ImageIcon scaled = new ImageIcon(imageScaled);
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
|
||||||
return new JLabel(scaled);
|
return new JLabel(scaled);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,371 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.List;
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import javax.swing.*;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.datavec.api.split.FileSplit;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
|
import org.datavec.image.transform.*;
|
||||||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static net.brutex.gan.App2Config.BATCHSIZE;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class App2 {
|
||||||
|
|
||||||
|
final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
|
||||||
|
|
||||||
|
static final int DIMENSIONS = 28;
|
||||||
|
static final int CHANNELS = 1;
|
||||||
|
final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
|
||||||
|
|
||||||
|
|
||||||
|
final boolean BIAS = true;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private JFrame frame2, frame;
|
||||||
|
static final String OUTPUT_DIR = "d:/out/";
|
||||||
|
|
||||||
|
final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
|
||||||
|
final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
|
||||||
|
|
||||||
|
@Test @Tag("long-running")
|
||||||
|
void runTest() throws IOException {
|
||||||
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
||||||
|
MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
|
||||||
|
FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans3"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
|
||||||
|
ImageTransform transform2 = new ShowImageTransform("Tester", 30);
|
||||||
|
ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
|
||||||
|
|
||||||
|
ImageTransform tr = new PipelineImageTransform.Builder()
|
||||||
|
.addImageTransform(transform) //convert to GREY SCALE
|
||||||
|
.addImageTransform(transform3)
|
||||||
|
//.addImageTransform(transform2)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
|
||||||
|
imageRecordReader.initialize(fileSplit, tr);
|
||||||
|
DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
|
||||||
|
trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
|
|
||||||
|
MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
|
||||||
|
MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
|
||||||
|
|
||||||
|
LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
|
||||||
|
.map((layer) -> {
|
||||||
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
|
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||||
|
} else {
|
||||||
|
return layer;
|
||||||
|
}
|
||||||
|
}).toArray(LayerConfiguration[]::new);
|
||||||
|
|
||||||
|
NeuralNetConfiguration netConfiguration =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("GAN")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.updater(App2Config.UPDATER)
|
||||||
|
.innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
|
||||||
|
.layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
// .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
// .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
|
||||||
|
|
||||||
|
dis.init(); log.debug("Discriminator network: {}", dis);
|
||||||
|
gen.init(); log.debug("Generator network: {}", gen);
|
||||||
|
gan.init(); log.debug("GAN network: {}", gan);
|
||||||
|
|
||||||
|
|
||||||
|
log.info("Generator Summary:\n{}", gen.summary());
|
||||||
|
log.info("GAN Summary:\n{}", gan.summary());
|
||||||
|
dis.addTrainingListeners(new PerformanceListener(3, true, "DIS"));
|
||||||
|
//gen.addTrainingListeners(new PerformanceListener(3, true, "GEN")); //is never trained separately from GAN
|
||||||
|
gan.addTrainingListeners(new PerformanceListener(3, true, "GAN"));
|
||||||
|
/*
|
||||||
|
Thread vt =
|
||||||
|
new Thread(
|
||||||
|
new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
while (true) {
|
||||||
|
visualize(0, 0, gen);
|
||||||
|
try {
|
||||||
|
Thread.sleep(10000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
vt.start();
|
||||||
|
*/
|
||||||
|
|
||||||
|
App2Display display = new App2Display();
|
||||||
|
//Repack training data with new fake/real label. Original MNist has 10 labels, one for each digit
|
||||||
|
DataSet data = null;
|
||||||
|
int j =0;
|
||||||
|
for(int i=0;i<App2Config.EPOCHS;i++) {
|
||||||
|
log.info("Epoch {}", i);
|
||||||
|
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
j++;
|
||||||
|
INDArray real = trainData.next().getFeatures();
|
||||||
|
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||||
|
|
||||||
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1,
|
||||||
|
Nd4j.rand(BATCHSIZE, App2Config.INPUT));
|
||||||
|
//sigmoid output is -1 to 1
|
||||||
|
fake.addi(1f).divi(2f);
|
||||||
|
|
||||||
|
if (j % 50 == 1) {
|
||||||
|
display.visualize(new INDArray[] {fake}, App2Config.OUTPUT_PER_PANEL, false);
|
||||||
|
display.visualize(new INDArray[] {real}, App2Config.OUTPUT_PER_PANEL, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
|
||||||
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
//if(real.length()/BATCHSIZE!=784) break;
|
||||||
|
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
|
||||||
|
|
||||||
|
}
|
||||||
|
//fit the discriminator
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
//reset the training data and fit the complete GAN
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_real));
|
||||||
|
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Updated GAN's generator from gen.");
|
||||||
|
updateGen(gen, gan);
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
//vt.stop();
|
||||||
|
|
||||||
|
/*
|
||||||
|
int j;
|
||||||
|
for (int i = 0; i < App2Config.EPOCHS; i++) { //epoch
|
||||||
|
j=0;
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
j++;
|
||||||
|
DataSet next = trainData.next();
|
||||||
|
// generate data
|
||||||
|
INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
|
||||||
|
|
||||||
|
//start next round if there are not enough images left to have a full batchsize dataset
|
||||||
|
if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
|
||||||
|
log.warn("Your total number of input images is not a multiple of {}, "
|
||||||
|
+ "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
//if(i%20 == 0) {
|
||||||
|
|
||||||
|
// frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
|
||||||
|
// frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
|
||||||
|
//}
|
||||||
|
//real.divi(255f);
|
||||||
|
|
||||||
|
// int batchSize = (int) real.shape()[0];
|
||||||
|
|
||||||
|
//INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
|
||||||
|
INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
|
||||||
|
|
||||||
|
INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
|
||||||
|
// when generator has TANH as activation - value range is -1 to 1
|
||||||
|
// when generator has SIGMOID, then range is 0 to 1
|
||||||
|
fake.addi(1f).divi(2f);
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
|
||||||
|
DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
|
||||||
|
|
||||||
|
dis.fit(data);
|
||||||
|
dis.fit(data);
|
||||||
|
// Update the discriminator in the GAN network
|
||||||
|
updateGan(gen, dis, gan);
|
||||||
|
|
||||||
|
gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
|
||||||
|
|
||||||
|
//Visualize and reporting
|
||||||
|
if (j % 10 == 1) {
|
||||||
|
System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
|
||||||
|
INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
|
||||||
|
|
||||||
|
|
||||||
|
for (int k = 0; k < samples.length; k++) {
|
||||||
|
DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
|
||||||
|
INDArray input = fakeSet2.get(k).getFeatures();
|
||||||
|
|
||||||
|
//input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
|
||||||
|
input = input.reshape(1, App2Config.INPUT);
|
||||||
|
|
||||||
|
//samples[k] = gen.output(input, false);
|
||||||
|
samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
|
||||||
|
samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
|
||||||
|
//samples[k] =
|
||||||
|
//samples[k].muli(255f);
|
||||||
|
|
||||||
|
}
|
||||||
|
frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trainData.resetSupported()) {
|
||||||
|
trainData.reset();
|
||||||
|
} else {
|
||||||
|
log.error("Trainingdata {} does not support reset.", trainData.toString());
|
||||||
|
}
|
||||||
|
// Copy the GANs generator to gen.
|
||||||
|
updateGen(gen, gan);
|
||||||
|
log.info("Updated GAN's generator from gen.");
|
||||||
|
gen.save(new File("mnist-mlp-generator.dlj"));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
|
||||||
|
for (int i = 0; i < gen.getLayers().length; i++) {
|
||||||
|
gen.getLayer(i).setParams(gan.getLayer(i).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
|
||||||
|
int genLayerCount = gen.getLayers().length;
|
||||||
|
for (int i = genLayerCount; i < gan.getLayers().length; i++) {
|
||||||
|
gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testDiskriminator() throws IOException {
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(App2Config.discriminator());
|
||||||
|
net.init();
|
||||||
|
net.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
|
||||||
|
DataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
|
||||||
|
|
||||||
|
DataSet data = null;
|
||||||
|
for(int i=0;i<App2Config.EPOCHS;i++) {
|
||||||
|
log.info("Epoch {}", i);
|
||||||
|
data = new DataSet(Nd4j.rand(BATCHSIZE, 784), label_fake);
|
||||||
|
while (trainData.hasNext()) {
|
||||||
|
INDArray real = trainData.next().getFeatures();
|
||||||
|
long[] l = new long[]{BATCHSIZE, real.length() / BATCHSIZE};
|
||||||
|
INDArray fake = Nd4j.rand(l );
|
||||||
|
|
||||||
|
DataSet realSet = new DataSet(real, label_real);
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
if(real.length()/BATCHSIZE!=784) break;
|
||||||
|
data = DataSet.merge(Arrays.asList(data, realSet, fakeSet));
|
||||||
|
|
||||||
|
}
|
||||||
|
net.fit(data);
|
||||||
|
trainData.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
long[] l = new long[]{BATCHSIZE, 784};
|
||||||
|
INDArray fake = Nd4j.rand(l );
|
||||||
|
DataSet fakeSet = new DataSet(fake, label_fake);
|
||||||
|
data = DataSet.merge(Arrays.asList(data, fakeSet));
|
||||||
|
ExistingDataSetIterator iter = new ExistingDataSetIterator(data);
|
||||||
|
Evaluation eval = net.evaluate(iter);
|
||||||
|
log.info( "\n" + eval.confusionMatrix());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,183 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import static net.brutex.ai.dnn.api.NN.*;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
public class App2Config {
|
||||||
|
|
||||||
|
public static final int INPUT = 100;
|
||||||
|
public static final int BATCHSIZE=150;
|
||||||
|
public static final int X_DIM = 28;
|
||||||
|
public static final int Y_DIM = 28;
|
||||||
|
public static final int CHANNELS = 1;
|
||||||
|
public static final int EPOCHS = 50;
|
||||||
|
public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
|
||||||
|
public static final IUpdater UPDATER_DIS = Adam.builder().learningRate(0.02).beta1(0.5).build();
|
||||||
|
public static final boolean SHOW_GENERATED = true;
|
||||||
|
public static final float COLORSPACE = 255f;
|
||||||
|
|
||||||
|
final static int OUTPUT_PER_PANEL = 10;
|
||||||
|
|
||||||
|
static LayerConfiguration[] genLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {
|
||||||
|
/*
|
||||||
|
DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(), /*
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(1,1)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().nOut(CHANNELS).build(),
|
||||||
|
Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(2,2)
|
||||||
|
.padding(0,0)
|
||||||
|
.convolutionMode(ConvolutionMode.Truncate)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.hasBias(BIAS).build(),
|
||||||
|
//BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
|
||||||
|
|
||||||
|
|
||||||
|
DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
|
||||||
|
// DropoutLayer.builder(0.001).build(),
|
||||||
|
DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build() */
|
||||||
|
|
||||||
|
dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(256).nOut(512).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(512).nOut(1024).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LayerConfiguration[] disLayerConfig() {
|
||||||
|
return new LayerConfiguration[] {/*
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
|
||||||
|
.build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
|
||||||
|
BatchNormalization.builder().build(),
|
||||||
|
OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build()
|
||||||
|
|
||||||
|
|
||||||
|
dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
|
||||||
|
ActivationLayer.builder().activation(Activation.RELU).build(),
|
||||||
|
DropoutLayer.builder(0.5).build(),
|
||||||
|
|
||||||
|
OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
|
||||||
|
.activation(Activation.SIGMOID)
|
||||||
|
.build() */
|
||||||
|
dense().nIn(784).nOut(1024).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(1024).nOut(512).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
dense().nIn(512).nOut(256).hasBias(true).build(),
|
||||||
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
|
OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static NeuralNetConfiguration generator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("generator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(App2Config.genLayerConfig())
|
||||||
|
// .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
|
||||||
|
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
static NeuralNetConfiguration discriminator() {
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.name("discriminator")
|
||||||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
|
.gradientNormalizationThreshold(100)
|
||||||
|
.seed(42)
|
||||||
|
.updater(UPDATER_DIS)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
// .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
|
.weightNoise(null)
|
||||||
|
// .weightInitFn(new WeightInitXavier())
|
||||||
|
// .activationFn(new ActivationIdentity())
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayerConfig())
|
||||||
|
//.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
|
||||||
|
//.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
|
||||||
|
//.dataType(DataType.FLOAT)
|
||||||
|
.build();
|
||||||
|
conf.init();
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,160 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import com.google.inject.Singleton;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.color.ColorSpace;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static net.brutex.gan.App2.OUTPUT_DIR;
|
||||||
|
import static net.brutex.gan.App2Config.*;
|
||||||
|
@Slf4j
|
||||||
|
@Singleton
|
||||||
|
public class App2Display {
|
||||||
|
|
||||||
|
private final JFrame frame = new JFrame();
|
||||||
|
private final App2GUI display = new App2GUI();
|
||||||
|
|
||||||
|
private final JPanel real_panel;
|
||||||
|
private final JPanel fake_panel;
|
||||||
|
|
||||||
|
|
||||||
|
public App2Display() {
|
||||||
|
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
frame.setContentPane(display.getOverall_panel());
|
||||||
|
frame.setMinimumSize(new Dimension(300, 20));
|
||||||
|
frame.pack();
|
||||||
|
frame.setVisible(true);
|
||||||
|
real_panel = display.getReal_panel();
|
||||||
|
fake_panel = display.getGen_panel();
|
||||||
|
real_panel.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
fake_panel.setLayout(new GridLayout(4, 4, 8, 8));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void visualize(INDArray[] samples, int batchElements, boolean isOrig) {
|
||||||
|
for (INDArray sample : samples) {
|
||||||
|
for(int i = 0; i<batchElements; i++) {
|
||||||
|
final Image img = this.getImage(sample, i, isOrig);
|
||||||
|
final ImageIcon icon = new ImageIcon(img);
|
||||||
|
if(isOrig) {
|
||||||
|
if(real_panel.getComponents().length>=OUTPUT_PER_PANEL) {
|
||||||
|
real_panel.remove(0);
|
||||||
|
}
|
||||||
|
real_panel.add(new JLabel(icon));
|
||||||
|
} else {
|
||||||
|
if(fake_panel.getComponents().length>=OUTPUT_PER_PANEL) {
|
||||||
|
fake_panel.remove(0);
|
||||||
|
}
|
||||||
|
fake_panel.add(new JLabel(icon));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
frame.pack();
|
||||||
|
frame.repaint();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Image getImage(INDArray tensor, int batchElement, boolean isOrig) {
|
||||||
|
final BufferedImage bi;
|
||||||
|
if(CHANNELS >1) {
|
||||||
|
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
|
||||||
|
} else {
|
||||||
|
bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
|
||||||
|
}
|
||||||
|
final int imageSize = X_DIM * Y_DIM;
|
||||||
|
final int offset = batchElement * imageSize;
|
||||||
|
int pxl = offset * CHANNELS; //where to start in the INDArray
|
||||||
|
|
||||||
|
//Image in NCHW - channels first format
|
||||||
|
for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
|
||||||
|
for (int y = 0; y < X_DIM; y++) { // step through the columns x
|
||||||
|
for (int x = 0; x < Y_DIM; x++) { //step through the rows y
|
||||||
|
float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
|
||||||
|
if(isOrig) log.trace("'{}.'{} Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, isOrig ? "Real" : "Fake", x, y, c, pxl, f_pxl);
|
||||||
|
bi.getRaster().setSample(x, y, c, f_pxl);
|
||||||
|
pxl++; //next item in INDArray
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ImageIcon orig = new ImageIcon(bi);
|
||||||
|
Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
|
||||||
|
ImageIcon scaled = new ImageIcon(imageScaled);
|
||||||
|
//if(! isOrig) saveImage(imageScaled, batchElement, isOrig);
|
||||||
|
return imageScaled;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void saveImage(Image image, int batchElement, boolean isOrig) {
|
||||||
|
String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the images to disk
|
||||||
|
saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
|
||||||
|
|
||||||
|
log.debug("Images saved successfully.");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("Error saving the images: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
|
||||||
|
File directory = new File(outputDirectory);
|
||||||
|
if (!directory.exists()) {
|
||||||
|
directory.mkdir();
|
||||||
|
}
|
||||||
|
|
||||||
|
File outputFile = new File(directory, fileName);
|
||||||
|
ImageIO.write(imageToBufferedImage(image), "png", outputFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage imageToBufferedImage(Image image) {
|
||||||
|
if (image instanceof BufferedImage) {
|
||||||
|
return (BufferedImage) image;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a buffered image with the same dimensions and transparency as the original image
|
||||||
|
BufferedImage bufferedImage;
|
||||||
|
if (CHANNELS > 1) {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
|
||||||
|
} else {
|
||||||
|
bufferedImage =
|
||||||
|
new BufferedImage(
|
||||||
|
image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw the original image onto the buffered image
|
||||||
|
Graphics2D g2d = bufferedImage.createGraphics();
|
||||||
|
g2d.drawImage(image, 0, 0, null);
|
||||||
|
g2d.dispose();
|
||||||
|
|
||||||
|
return bufferedImage;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package net.brutex.gan;
|
||||||
|
|
||||||
|
import javax.swing.JPanel;
|
||||||
|
import javax.swing.JSplitPane;
|
||||||
|
import javax.swing.JLabel;
|
||||||
|
import java.awt.BorderLayout;
|
||||||
|
|
||||||
|
/**
 * Panel holding a split pane with two captioned areas: "Generator" on the left and "Real" on
 * the right. Each area places its caption at the top and an initially empty content panel at
 * the bottom; the content panels are exposed through the getters.
 */
public class App2GUI extends JPanel {

  private static final long serialVersionUID = 1L;

  private final JPanel overall_panel;
  private final JPanel real_panel;
  private final JPanel gen_panel;

  /**
   * Create the panel.
   */
  public App2GUI() {
    overall_panel = new JPanel();
    add(overall_panel);

    final JSplitPane divider = new JSplitPane();
    overall_panel.add(divider);

    gen_panel = new JPanel();
    divider.setLeftComponent(captioned("Generator", gen_panel));

    real_panel = new JPanel();
    divider.setRightComponent(captioned("Real", real_panel));
  }

  /** Wraps {@code content} in a border-layout panel with {@code caption} shown above it. */
  private static JPanel captioned(String caption, JPanel content) {
    final JPanel wrapper = new JPanel();
    wrapper.setLayout(new BorderLayout(0, 0));
    wrapper.add(new JLabel(caption), BorderLayout.NORTH);
    wrapper.add(content, BorderLayout.SOUTH);
    return wrapper;
  }

  /** @return the outermost panel containing the split pane */
  public JPanel getOverall_panel() {
    return overall_panel;
  }

  /** @return the content panel of the "Real" side */
  public JPanel getReal_panel() {
    return real_panel;
  }

  /** @return the content panel of the "Generator" side */
  public JPanel getGen_panel() {
    return gen_panel;
  }
}
|
|
@ -24,12 +24,15 @@ package net.brutex.gan;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -98,7 +101,10 @@ public class MnistSimpleGAN {
|
||||||
|
|
||||||
return new MultiLayerNetwork(discConf);
|
return new MultiLayerNetwork(discConf);
|
||||||
}
|
}
|
||||||
|
@Test @Tag("long-running")
|
||||||
|
public void runTest() throws Exception {
|
||||||
|
main(null);
|
||||||
|
}
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
GAN gan = new GAN.Builder()
|
GAN gan = new GAN.Builder()
|
||||||
.generator(MnistSimpleGAN::getGenerator)
|
.generator(MnistSimpleGAN::getGenerator)
|
||||||
|
@ -108,6 +114,7 @@ public class MnistSimpleGAN {
|
||||||
.updater(UPDATER)
|
.updater(UPDATER)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.gradientNormalizationThreshold(100)
|
.gradientNormalizationThreshold(100)
|
||||||
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
# Default logging detail level for all instances of SimpleLogger.
|
# Default logging detail level for all instances of SimpleLogger.
|
||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
# If not specified, defaults to "info".
|
# If not specified, defaults to "info".
|
||||||
org.slf4j.simpleLogger.defaultLogLevel=trace
|
org.slf4j.simpleLogger.defaultLogLevel=debug
|
||||||
|
|
||||||
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
# Logging detail level for a SimpleLogger instance named "xxxxx".
|
||||||
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
# Must be one of ("trace", "debug", "info", "warn", or "error").
|
||||||
|
@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
|
||||||
# If the format is not specified or is invalid, the default format is used.
|
# If the format is not specified or is invalid, the default format is used.
|
||||||
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
# The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
|
||||||
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
|
||||||
org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
|
||||||
|
|
||||||
# Set to true if you want to output the current thread name.
|
# Set to true if you want to output the current thread name.
|
||||||
# Defaults to true.
|
# Defaults to true.
|
||||||
org.slf4j.simpleLogger.showThreadName=true
|
#org.slf4j.simpleLogger.showThreadName=true
|
|
@ -88,7 +88,7 @@ public class CNN1DTestCases {
|
||||||
.convolutionMode(ConvolutionMode.Same))
|
.convolutionMode(ConvolutionMode.Same))
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", Convolution1DLayer.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
.layer("0", Convolution1D.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
||||||
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
||||||
.layer("2", Cropping1D.builder(1).build(), "1")
|
.layer("2", Cropping1D.builder(1).build(), "1")
|
||||||
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
||||||
|
|
|
@ -20,7 +20,7 @@ ext {
|
||||||
|
|
||||||
def javacv = [version:"1.5.7"]
|
def javacv = [version:"1.5.7"]
|
||||||
def opencv = [version: "4.5.5"]
|
def opencv = [version: "4.5.5"]
|
||||||
def leptonica = [version: "1.82.0"]
|
def leptonica = [version: "1.83.0"] //fix, only in javacpp 1.5.9
|
||||||
def junit = [version: "5.9.1"]
|
def junit = [version: "5.9.1"]
|
||||||
|
|
||||||
def flatbuffers = [version: "1.10.0"]
|
def flatbuffers = [version: "1.10.0"]
|
||||||
|
@ -71,7 +71,7 @@ dependencies {
|
||||||
// api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
|
// api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"
|
||||||
|
|
||||||
|
|
||||||
api "org.projectlombok:lombok:1.18.26"
|
api "org.projectlombok:lombok:1.18.28"
|
||||||
|
|
||||||
/*Logging*/
|
/*Logging*/
|
||||||
api 'org.slf4j:slf4j-api:2.0.3'
|
api 'org.slf4j:slf4j-api:2.0.3'
|
||||||
|
@ -118,7 +118,8 @@ dependencies {
|
||||||
api "org.bytedeco:javacv:${javacv.version}"
|
api "org.bytedeco:javacv:${javacv.version}"
|
||||||
api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:opencv:${opencv.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:openblas:${openblas.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:leptonica-platform:${leptonica.version}-1.5.9"
|
||||||
|
api "org.bytedeco:leptonica:${leptonica.version}-1.5.9"
|
||||||
api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
|
@ -129,6 +130,7 @@ dependencies {
|
||||||
api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
|
api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}"
|
||||||
|
api "org.bytedeco:mkl:2022.0-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
|
api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}"
|
||||||
api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}"
|
||||||
|
|
|
@ -28,7 +28,8 @@ dependencies {
|
||||||
implementation "org.bytedeco:javacv"
|
implementation "org.bytedeco:javacv"
|
||||||
implementation "org.bytedeco:opencv"
|
implementation "org.bytedeco:opencv"
|
||||||
implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
|
implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget
|
||||||
implementation "org.bytedeco:leptonica-platform"
|
//implementation "org.bytedeco:leptonica-platform"
|
||||||
|
implementation group: "org.bytedeco", name: "leptonica", classifier: buildTarget
|
||||||
implementation "org.bytedeco:hdf5-platform"
|
implementation "org.bytedeco:hdf5-platform"
|
||||||
|
|
||||||
implementation "commons-io:commons-io"
|
implementation "commons-io:commons-io"
|
||||||
|
|
|
@ -46,7 +46,7 @@ import java.nio.ByteOrder;
|
||||||
import org.bytedeco.leptonica.*;
|
import org.bytedeco.leptonica.*;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
import static org.bytedeco.leptonica.global.lept.*;
|
import static org.bytedeco.leptonica.global.leptonica.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
|
|
|
@ -52,10 +52,9 @@ import java.io.InputStream;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.stream.IntStream;
|
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.bytedeco.leptonica.global.lept.*;
|
import static org.bytedeco.leptonica.global.leptonica.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
|
@ -2386,7 +2386,11 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
long[] stride();
|
long[] stride();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray
|
* Return the ordering (fortran or c 'f' and 'c' respectively) of this ndarray <br/><br/>
|
||||||
|
* C Is Contiguous layout. Mathematically speaking, row major.<br/>
|
||||||
|
* F Is Fortran contiguous layout. Mathematically speaking, column major.<br/>
|
||||||
|
* See <a href="https://en.wikipedia.org/wiki/Row-_and_column-major_order">Row- and column-major order</a>.<br/>
|
||||||
|
*
|
||||||
* @return the ordering of this ndarray
|
* @return the ordering of this ndarray
|
||||||
*/
|
*/
|
||||||
char ordering();
|
char ordering();
|
||||||
|
|
|
@ -334,6 +334,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
|
||||||
public void save(File to) {
|
public void save(File to) {
|
||||||
try (FileOutputStream fos = new FileOutputStream(to, false);
|
try (FileOutputStream fos = new FileOutputStream(to, false);
|
||||||
BufferedOutputStream bos = new BufferedOutputStream(fos)) {
|
BufferedOutputStream bos = new BufferedOutputStream(fos)) {
|
||||||
|
to.mkdirs();
|
||||||
save(bos);
|
save(bos);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
|
|
@ -5121,7 +5121,7 @@ public class Nd4j {
|
||||||
Nd4j.backend = backend;
|
Nd4j.backend = backend;
|
||||||
updateNd4jContext();
|
updateNd4jContext();
|
||||||
props = Nd4jContext.getInstance().getConf();
|
props = Nd4jContext.getInstance().getConf();
|
||||||
logger.info("Properties for Nd4jContext " + props);
|
log.debug("Properties for Nd4jContext {}", props);
|
||||||
PropertyParser pp = new PropertyParser(props);
|
PropertyParser pp = new PropertyParser(props);
|
||||||
|
|
||||||
String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
|
String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
|
||||||
|
|
|
@ -166,10 +166,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
int seed = 123;
|
int seed = 123;
|
||||||
int listenerFreq = 1;
|
int listenerFreq = 1;
|
||||||
|
|
||||||
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
|
final LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
|
||||||
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
|
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
|
final var builder = NeuralNetConfiguration.builder().seed(seed)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
|
.layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
|
||||||
|
@ -181,7 +181,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
|
.inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
|
||||||
|
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
|
final MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
|
||||||
model.init();
|
model.init();
|
||||||
|
|
||||||
model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
|
model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
|
||||||
|
|
|
@ -45,6 +45,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||||
|
@ -924,8 +925,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(EpochTerminationCondition e : etc ){
|
for(EpochTerminationCondition e : etc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(e);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(e);
|
||||||
EpochTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, EpochTerminationCondition.class);
|
EpochTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, EpochTerminationCondition.class);
|
||||||
assertEquals(e, c);
|
assertEquals(e, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -936,8 +937,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(IterationTerminationCondition i : itc ){
|
for(IterationTerminationCondition i : itc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(i);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(i);
|
||||||
IterationTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, IterationTerminationCondition.class);
|
IterationTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, IterationTerminationCondition.class);
|
||||||
assertEquals(i, c);
|
assertEquals(i, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -309,7 +309,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
|
||||||
.list()
|
|
||||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
|
.layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, OutputLayer.builder().nOut(10).build())
|
.layer(1, OutputLayer.builder().nOut(10).build())
|
||||||
|
|
|
@ -77,7 +77,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder().updater(new NoOp())
|
NeuralNetConfiguration.builder().updater(new NoOp())
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.seed(12345L)
|
.seed(12345L)
|
||||||
.dist(new NormalDistribution(0, 1)).list()
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
||||||
.activation(Activation.IDENTITY).build())
|
.activation(Activation.IDENTITY).build())
|
||||||
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
||||||
|
@ -122,7 +122,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).seed(12345L)
|
.updater(new NoOp()).seed(12345L)
|
||||||
.dist(new NormalDistribution(0, 2)).list()
|
.dist(new NormalDistribution(0, 2)).list()
|
||||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
|
.layer(0, Convolution2D.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
|
||||||
.activation(Activation.IDENTITY).build())
|
.activation(Activation.IDENTITY).build())
|
||||||
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).build())
|
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).build())
|
||||||
.layer(2, ActivationLayer.builder().activation(Activation.TANH).build())
|
.layer(2, ActivationLayer.builder().activation(Activation.TANH).build())
|
||||||
|
|
|
@ -91,9 +91,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -202,7 +201,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -211,7 +210,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(Cropping1D.builder(cropping).build())
|
.layer(Cropping1D.builder(cropping).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -317,7 +316,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -326,7 +325,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -435,10 +434,9 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
|
||||||
.layer(
|
.layer(
|
||||||
0,
|
0,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -447,7 +445,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
1,
|
1,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -461,6 +459,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
.padding(padding)
|
.padding(padding)
|
||||||
.pnorm(pnorm)
|
.pnorm(pnorm)
|
||||||
|
.name("SubsamplingLayer")
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
3,
|
3,
|
||||||
|
@ -548,7 +547,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -562,7 +561,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.pnorm(pnorm)
|
.pnorm(pnorm)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -655,7 +654,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
|
@ -664,7 +663,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.nOut(convNOut1)
|
.nOut(convNOut1)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.convolutionMode(ConvolutionMode.Causal)
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
|
|
@ -0,0 +1,811 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.gradientcheck;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.util.Convolution1DUtils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class CNN1DNewGradientCheckTest extends BaseDL4JTest {
|
||||||
|
// Forwarded to GradientCheckUtil.checkGradients and also used by the tests in this class to gate
// the per-configuration System.out.println of the test description.
private static final boolean PRINT_RESULTS = true;
|
||||||
|
// Forwarded to GradientCheckUtil.checkGradients; presumably stops the check at the first
// mismatching parameter when true — confirm against GradientCheckUtil.
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
||||||
|
// Second argument of GradientCheckUtil.checkGradients (presumably the finite-difference step).
private static final double DEFAULT_EPS = 1e-6;
|
||||||
|
// Third argument of GradientCheckUtil.checkGradients (presumably the relative error threshold).
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
||||||
|
// Fourth argument of GradientCheckUtil.checkGradients (presumably the absolute error floor).
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
|
// Runs once when the class is loaded: switches ND4J's data type to DOUBLE before any test runs.
static {
|
||||||
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1D() {
|
||||||
|
int minibatchSize = 4;
|
||||||
|
int[] dataChannels = {4, 10}; //the input
|
||||||
|
int[] kernels = {2,4,5,8};
|
||||||
|
int stride = 2;
|
||||||
|
int padding = 3;
|
||||||
|
int seriesLength = 300;
|
||||||
|
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
for (int dChannels : dataChannels) {
|
||||||
|
int numLabels = ((seriesLength + (2 * padding) - kernel) / stride) + 1;
|
||||||
|
final NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(dChannels) // channels
|
||||||
|
.nOut(3)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(4)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(dChannels, seriesLength))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, dChannels, seriesLength);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, 4, numLabels);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < numLabels; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % 4, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
String msg =
|
||||||
|
"Minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ Activation.RELU
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
System.out.println(msg);
|
||||||
|
for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
|
/**
|
||||||
|
List<Pair<INDArray, INDArray>> iter = new java.util.ArrayList<>(Collections.emptyList());
|
||||||
|
iter.add(new Pair<>(input, labels));
|
||||||
|
for(int x=0;x<100; x++) net.fit(input, labels);
|
||||||
|
Evaluation eval = net.evaluate(new INDArrayDataSetIterator(iter,2), Arrays.asList(new String[]{"One", "Two", "Three", "Four"}));
|
||||||
|
// net.fit(input, labels);
|
||||||
|
eval.eval(labels, net.output(input));
|
||||||
|
|
||||||
|
**/
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithLocallyConnected1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {2, 3};
|
||||||
|
int length = 25;
|
||||||
|
int convNIn = 18;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1,2,4};
|
||||||
|
int stride = 1;
|
||||||
|
int padding = 0;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < length; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(convNIn)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
LocallyConnected1D.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nIn(convNOut1)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.hasBias(false)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithCropping1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
|
||||||
|
int padding = 0;
|
||||||
|
int cropping = 1;
|
||||||
|
int croppedLength = length - 2 * cropping;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < croppedLength; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(Cropping1D.builder(cropping).build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithZeroPadding1D() {
|
||||||
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
int padding = 0;
|
||||||
|
int zeroPadding = 2;
|
||||||
|
int paddedLength = length + 2 * zeroPadding;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < paddedLength; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(2, kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(ZeroPadding1DLayer.builder(0).build())
|
||||||
|
.layer(
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1DWithSubsampling1D() {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
int[] minibatchSizes = {1, 3};
|
||||||
|
int length = 7;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 4;
|
||||||
|
|
||||||
|
int[] kernels = {1, 2, 4};
|
||||||
|
int stride = 1;
|
||||||
|
int padding = 0;
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX,
|
||||||
|
SubsamplingLayer.PoolingType.AVG,
|
||||||
|
SubsamplingLayer.PoolingType.PNORM
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Activation afn : activations) {
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (int minibatchSize : minibatchSizes) {
|
||||||
|
for (int kernel : kernels) {
|
||||||
|
INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
|
||||||
|
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
for (int j = 0; j < length; j++) {
|
||||||
|
labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
|
.layer(
|
||||||
|
0,
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
1,
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.activation(afn)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
2,
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(kernel)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.name("SubsamplingLayer")
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
3,
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String json = conf.toJson();
|
||||||
|
NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf, c2);
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"PoolingType="
|
||||||
|
+ poolingType
|
||||||
|
+ ", minibatch="
|
||||||
|
+ minibatchSize
|
||||||
|
+ ", activationFn="
|
||||||
|
+ afn
|
||||||
|
+ ", kernel = "
|
||||||
|
+ kernel;
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
// System.out.println("ILayer " + j + " # params: " +
|
||||||
|
// net.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
net,
|
||||||
|
DEFAULT_EPS,
|
||||||
|
DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR,
|
||||||
|
PRINT_RESULTS,
|
||||||
|
RETURN_ON_FIRST_FAILURE,
|
||||||
|
input,
|
||||||
|
labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK, msg);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1dWithMasking() {
|
||||||
|
int length = 12;
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 3;
|
||||||
|
|
||||||
|
int pnorm = 2;
|
||||||
|
|
||||||
|
SubsamplingLayer.PoolingType[] poolingTypes =
|
||||||
|
new SubsamplingLayer.PoolingType[] {
|
||||||
|
SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG
|
||||||
|
};
|
||||||
|
|
||||||
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
|
for (ConvolutionMode cm :
|
||||||
|
new ConvolutionMode[] {ConvolutionMode.Same, ConvolutionMode.Truncate}) {
|
||||||
|
for (int stride : new int[] {1, 2}) {
|
||||||
|
String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
|
||||||
|
log.info("Starting test: " + s);
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.convolutionMode(cm)
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(2)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.stride(stride)
|
||||||
|
.nIn(convNIn)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Subsampling1DLayer.builder(poolingType)
|
||||||
|
.kernelSize(2)
|
||||||
|
.stride(stride)
|
||||||
|
.pnorm(pnorm)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(2)
|
||||||
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
|
.stride(stride)
|
||||||
|
.nIn(convNOut1)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
|
.layer(
|
||||||
|
OutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray f = Nd4j.rand(2, convNIn, length);
|
||||||
|
INDArray fm = Nd4j.create(2, length);
|
||||||
|
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
|
||||||
|
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
|
||||||
|
|
||||||
|
INDArray label = TestUtils.randomOneHot(2, finalNOut);
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
|
||||||
|
|
||||||
|
assertTrue(gradOK, s);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
||||||
|
// TODO also check that masked step values don't impact forward pass, score or gradients
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(f, label, fm, null);
|
||||||
|
double scoreBefore = net.score(ds);
|
||||||
|
net.setInput(f);
|
||||||
|
net.setLabels(label);
|
||||||
|
net.setLayerMaskArrays(fm, null);
|
||||||
|
net.computeGradientAndScore();
|
||||||
|
INDArray gradBefore = net.getFlattenedGradients().dup();
|
||||||
|
f.putScalar(1, 0, 10, 10.0);
|
||||||
|
f.putScalar(1, 1, 11, 20.0);
|
||||||
|
double scoreAfter = net.score(ds);
|
||||||
|
net.setInput(f);
|
||||||
|
net.setLabels(label);
|
||||||
|
net.setLayerMaskArrays(fm, null);
|
||||||
|
net.computeGradientAndScore();
|
||||||
|
INDArray gradAfter = net.getFlattenedGradients().dup();
|
||||||
|
|
||||||
|
assertEquals(scoreBefore, scoreAfter, 1e-6);
|
||||||
|
assertEquals(gradBefore, gradAfter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnn1Causal() throws Exception {
|
||||||
|
int convNIn = 2;
|
||||||
|
int convNOut1 = 3;
|
||||||
|
int convNOut2 = 4;
|
||||||
|
int finalNOut = 3;
|
||||||
|
|
||||||
|
int[] lengths = {11, 12, 13, 9, 10, 11};
|
||||||
|
int[] kernels = {2, 3, 2, 4, 2, 3};
|
||||||
|
int[] dilations = {1, 1, 2, 1, 2, 1};
|
||||||
|
int[] strides = {1, 2, 1, 2, 1, 1};
|
||||||
|
boolean[] masks = {false, true, false, true, false, true};
|
||||||
|
boolean[] hasB = {true, false, true, false, true, true};
|
||||||
|
for (int i = 0; i < lengths.length; i++) {
|
||||||
|
int length = lengths[i];
|
||||||
|
int k = kernels[i];
|
||||||
|
int d = dilations[i];
|
||||||
|
int st = strides[i];
|
||||||
|
boolean mask = masks[i];
|
||||||
|
boolean hasBias = hasB[i];
|
||||||
|
// TODO has bias
|
||||||
|
String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
|
||||||
|
log.info("Starting test: " + s);
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
NeuralNetConfiguration conf =
|
||||||
|
NeuralNetConfiguration.builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
|
.seed(12345)
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(k)
|
||||||
|
.dilation(d)
|
||||||
|
.hasBias(hasBias)
|
||||||
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
.stride(st)
|
||||||
|
.nOut(convNOut1)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
Convolution1DNew.builder()
|
||||||
|
.kernelSize(k)
|
||||||
|
.dilation(d)
|
||||||
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
.stride(st)
|
||||||
|
.nOut(convNOut2)
|
||||||
|
.build())
|
||||||
|
.layer(
|
||||||
|
RnnOutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nOut(finalNOut)
|
||||||
|
.build())
|
||||||
|
.inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
|
||||||
|
INDArray fm = null;
|
||||||
|
if (mask) {
|
||||||
|
fm = Nd4j.create(2, length);
|
||||||
|
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
|
||||||
|
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
|
||||||
|
long outSize2 =
|
||||||
|
Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
|
||||||
|
|
||||||
|
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
|
||||||
|
|
||||||
|
String msg =
|
||||||
|
"Minibatch="
|
||||||
|
+ 1
|
||||||
|
+ ", activationFn="
|
||||||
|
+ Activation.RELU
|
||||||
|
+ ", kernel = "
|
||||||
|
+ k;
|
||||||
|
|
||||||
|
System.out.println(msg);
|
||||||
|
for (int j = 0; j < net.getnLayers(); j++)
|
||||||
|
System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
|
|
||||||
|
boolean gradOK =
|
||||||
|
GradientCheckUtil.checkGradients(
|
||||||
|
new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
|
||||||
|
|
||||||
|
assertTrue(gradOK, s);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -112,9 +112,8 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.list()
|
|
||||||
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
||||||
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
||||||
.convolutionMode(mode).dataFormat(df)
|
.convolutionMode(mode).dataFormat(df)
|
||||||
|
@ -400,7 +399,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
|
||||||
.layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
|
.layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
|
||||||
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
||||||
.convolutionMode(mode).dataFormat(df)
|
.convolutionMode(mode).dataFormat(df)
|
||||||
|
|
|
@ -108,8 +108,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.seed(12345L)
|
.seed(12345L)
|
||||||
.list()
|
|
||||||
.layer(0, ConvolutionLayer.builder(1, 1).nOut(6).activation(afn).build())
|
.layer(0, Convolution2D.builder().kernelSize(1).stride(1).nOut(6).activation(afn).build())
|
||||||
.layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
|
.layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
|
||||||
.inputType(InputType.convolutionalFlat(1, 4, 1));
|
.inputType(InputType.convolutionalFlat(1, 4, 1));
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -336,7 +337,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
// to ensure that we carry the parameters through
|
// to ensure that we carry the parameters through
|
||||||
// the serializer.
|
// the serializer.
|
||||||
try{
|
try{
|
||||||
ObjectMapper m = NeuralNetConfiguration.mapper();
|
ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
String s = m.writeValueAsString(lossFunctions[i]);
|
String s = m.writeValueAsString(lossFunctions[i]);
|
||||||
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
||||||
lossFunctions[i] = lf2;
|
lossFunctions[i] = lf2;
|
||||||
|
|
|
@ -180,7 +180,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
Pooling2D.class, //Alias for SubsamplingLayer
|
Pooling2D.class, //Alias for SubsamplingLayer
|
||||||
Convolution2D.class, //Alias for ConvolutionLayer
|
Convolution2D.class, //Alias for ConvolutionLayer
|
||||||
Pooling1D.class, //Alias for Subsampling1D
|
Pooling1D.class, //Alias for Subsampling1D
|
||||||
Convolution1D.class, //Alias for Convolution1DLayer
|
Convolution1D.class, //Alias for Convolution1D
|
||||||
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.Timeout;
|
import org.junit.jupiter.api.Timeout;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -1026,7 +1026,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
} catch (DL4JInvalidInputException e) {
|
} catch (DL4JInvalidInputException e) {
|
||||||
// e.printStackTrace();
|
// e.printStackTrace();
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
assertTrue(msg.contains(Convolution2DUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -921,7 +921,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder()
|
NeuralNetConfiguration.builder()
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
|
@ -975,7 +975,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1dCausalAllowed() {
|
public void testConv1dCausalAllowed() {
|
||||||
Convolution1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Convolution1D.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -346,7 +346,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
int[] outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -367,7 +367,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(3, it.getHeight());
|
assertEquals(3, it.getHeight());
|
||||||
assertEquals(3, it.getWidth());
|
assertEquals(3, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(3, outSize[0]);
|
assertEquals(3, outSize[0]);
|
||||||
assertEquals(3, outSize[1]);
|
assertEquals(3, outSize[1]);
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
fail("Exception expected");
|
fail("Exception expected");
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
|
@ -409,7 +409,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(1, it.getHeight());
|
assertEquals(1, it.getHeight());
|
||||||
assertEquals(1, it.getWidth());
|
assertEquals(1, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(1, outSize[0]);
|
assertEquals(1, outSize[0]);
|
||||||
assertEquals(1, outSize[1]);
|
assertEquals(1, outSize[1]);
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -732,7 +732,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
||||||
Convolution1DLayer.builder().kernelSize(3).stride(1).nOut(3).build())
|
Convolution1D.builder().kernelSize(3).stride(1).nOut(3).build())
|
||||||
.layer(BatchNormalization.builder().build())
|
.layer(BatchNormalization.builder().build())
|
||||||
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
||||||
.inputType(InputType.recurrent(3))
|
.inputType(InputType.recurrent(3))
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(output)
|
.setOutputs(output)
|
||||||
.layer(conv, Convolution1DLayer.builder(7)
|
.layer(conv, Convolution1D.builder(7)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.nOut(input.size(1))
|
.nOut(input.size(1))
|
||||||
.weightInit(new WeightInitIdentity())
|
.weightInit(new WeightInitIdentity())
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.regressiontest;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
||||||
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
||||||
|
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
for (Distribution d : distributions) {
|
for (Distribution d : distributions) {
|
||||||
String json = om.writeValueAsString(d);
|
String json = om.writeValueAsString(d);
|
||||||
|
@ -50,7 +51,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
||||||
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
||||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
||||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.jita.allocator.Allocator;
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
|
@ -681,9 +681,9 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
|
|
||||||
int[] outSize;
|
int[] outSize;
|
||||||
if (convolutionMode == ConvolutionMode.Same) {
|
if (convolutionMode == ConvolutionMode.Same) {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
||||||
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
padding = Convolution2DUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
int[] padBottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
if(!Arrays.equals(padding, padBottomRight)){
|
if(!Arrays.equals(padding, padBottomRight)){
|
||||||
/*
|
/*
|
||||||
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
||||||
|
@ -731,7 +731,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
// CuDNN handle
|
// CuDNN handle
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
||||||
|
|
|
@ -42,7 +42,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.common.primitives.Counter;
|
import org.nd4j.common.primitives.Counter;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -442,8 +442,8 @@ public class KerasModel {
|
||||||
KerasInput kerasInput = (KerasInput) layer;
|
KerasInput kerasInput = (KerasInput) layer;
|
||||||
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
||||||
//no dim order, try to pull it from the next layer if there is one
|
//no dim order, try to pull it from the next layer if there is one
|
||||||
if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
|
if(Convolution2DUtils.layerHasConvolutionLayout(layer1)) {
|
||||||
CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
|
CNN2DFormat formatForLayer = Convolution2DUtils.getFormatForLayer(layer1);
|
||||||
if(formatForLayer == CNN2DFormat.NCHW) {
|
if(formatForLayer == CNN2DFormat.NCHW) {
|
||||||
dimOrder = KerasLayer.DimOrder.THEANO;
|
dimOrder = KerasLayer.DimOrder.THEANO;
|
||||||
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
||||||
|
|
|
@ -52,28 +52,44 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
* @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
|
||||||
*/
|
*/
|
||||||
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
public KerasSequentialModel(KerasModelBuilder modelBuilder)
|
||||||
throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
|
throws UnsupportedKerasConfigurationException,
|
||||||
this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
|
IOException,
|
||||||
modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
|
InvalidKerasConfigurationException {
|
||||||
modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
|
this(
|
||||||
|
modelBuilder.getModelJson(),
|
||||||
|
modelBuilder.getModelYaml(),
|
||||||
|
modelBuilder.getWeightsArchive(),
|
||||||
|
modelBuilder.getWeightsRoot(),
|
||||||
|
modelBuilder.getTrainingJson(),
|
||||||
|
modelBuilder.getTrainingArchive(),
|
||||||
|
modelBuilder.isEnforceTrainingConfig(),
|
||||||
|
modelBuilder.getInputShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* (Not recommended) Constructor for Sequential model from model configuration
|
* (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
|
||||||
* (JSON or YAML), training configuration (JSON), weights, and "training mode"
|
* training configuration (JSON), weights, and "training mode" boolean indicator. When built in
|
||||||
* boolean indicator. When built in training mode, certain unsupported configurations
|
* training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
|
||||||
* (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
|
* Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
|
||||||
* will generate warnings but will be otherwise ignored.
|
* otherwise ignored.
|
||||||
*
|
*
|
||||||
* @param modelJson model configuration JSON string
|
* @param modelJson model configuration JSON string
|
||||||
* @param modelYaml model configuration YAML string
|
* @param modelYaml model configuration YAML string
|
||||||
* @param trainingJson training configuration JSON string
|
* @param trainingJson training configuration JSON string
|
||||||
* @throws IOException I/O exception
|
* @throws IOException I/O exception
|
||||||
*/
|
*/
|
||||||
public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
|
public KerasSequentialModel(
|
||||||
String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
|
String modelJson,
|
||||||
|
String modelYaml,
|
||||||
|
Hdf5Archive weightsArchive,
|
||||||
|
String weightsRoot,
|
||||||
|
String trainingJson,
|
||||||
|
Hdf5Archive trainingArchive,
|
||||||
|
boolean enforceTrainingConfig,
|
||||||
int[] inputShape)
|
int[] inputShape)
|
||||||
throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
throws IOException,
|
||||||
|
InvalidKerasConfigurationException,
|
||||||
|
UnsupportedKerasConfigurationException {
|
||||||
|
|
||||||
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
|
||||||
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
|
||||||
|
@ -83,19 +99,29 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
/* Determine model configuration type. */
|
/* Determine model configuration type. */
|
||||||
if (!modelConfig.containsKey(config.getFieldClassName()))
|
if (!modelConfig.containsKey(config.getFieldClassName()))
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
|
"Could not determine Keras model class (no "
|
||||||
|
+ config.getFieldClassName()
|
||||||
|
+ " field found)");
|
||||||
this.className = (String) modelConfig.get(config.getFieldClassName());
|
this.className = (String) modelConfig.get(config.getFieldClassName());
|
||||||
if (!this.className.equals(config.getFieldClassNameSequential()))
|
if (!this.className.equals(config.getFieldClassNameSequential()))
|
||||||
throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
|
throw new InvalidKerasConfigurationException(
|
||||||
+ " (found " + this.className + ")");
|
"Model class name must be "
|
||||||
|
+ config.getFieldClassNameSequential()
|
||||||
|
+ " (found "
|
||||||
|
+ this.className
|
||||||
|
+ ")");
|
||||||
|
|
||||||
/* Process layer configurations. */
|
/* Process layer configurations. */
|
||||||
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
if (!modelConfig.containsKey(config.getModelFieldConfig()))
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
|
"Could not find layer configurations (no "
|
||||||
|
+ config.getModelFieldConfig()
|
||||||
|
+ " field found)");
|
||||||
|
|
||||||
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
|
// Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
|
||||||
// "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
|
// For consistency
|
||||||
|
// "config" is now an object containing a "name" and "layers", the latter contain the same data
|
||||||
|
// as before.
|
||||||
// This change only affects Sequential models.
|
// This change only affects Sequential models.
|
||||||
List<Object> layerList;
|
List<Object> layerList;
|
||||||
try {
|
try {
|
||||||
|
@ -105,8 +131,7 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
layerList = (List<Object>) layerMap.get("layers");
|
layerList = (List<Object>) layerMap.get("layers");
|
||||||
}
|
}
|
||||||
|
|
||||||
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
|
Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
|
||||||
prepareLayers(layerList);
|
|
||||||
this.layers = layerPair.getFirst();
|
this.layers = layerPair.getFirst();
|
||||||
this.layersOrdered = layerPair.getSecond();
|
this.layersOrdered = layerPair.getSecond();
|
||||||
|
|
||||||
|
@ -116,15 +141,18 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
} else {
|
} else {
|
||||||
/* Add placeholder input layer and update lists of input and output layers. */
|
/* Add placeholder input layer and update lists of input and output layers. */
|
||||||
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
|
||||||
Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
|
Preconditions.checkState(
|
||||||
|
ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
|
||||||
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
inputLayer = new KerasInput("input1", firstLayerInputShape);
|
||||||
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
|
||||||
this.layers.put(inputLayer.getName(), inputLayer);
|
this.layers.put(inputLayer.getName(), inputLayer);
|
||||||
this.layersOrdered.add(0, inputLayer);
|
this.layersOrdered.add(0, inputLayer);
|
||||||
}
|
}
|
||||||
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
|
||||||
this.outputLayerNames = new ArrayList<>(
|
this.outputLayerNames =
|
||||||
Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
new ArrayList<>(
|
||||||
|
Collections.singletonList(
|
||||||
|
this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
|
||||||
|
|
||||||
/* Update each layer's inbound layer list to include (only) previous layer. */
|
/* Update each layer's inbound layer list to include (only) previous layer. */
|
||||||
KerasLayer prevLayer = null;
|
KerasLayer prevLayer = null;
|
||||||
|
@ -136,12 +164,13 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
|
|
||||||
/* Import training configuration. */
|
/* Import training configuration. */
|
||||||
if (enforceTrainingConfig) {
|
if (enforceTrainingConfig) {
|
||||||
if (trainingJson != null)
|
if (trainingJson != null) importTrainingConfiguration(trainingJson);
|
||||||
importTrainingConfiguration(trainingJson);
|
else
|
||||||
else log.warn("If enforceTrainingConfig is true, a training " +
|
log.warn(
|
||||||
"configuration object has to be provided. Usually the only practical way to do this is to store" +
|
"If enforceTrainingConfig is true, a training "
|
||||||
" your keras model with `model.save('model_path.h5'. If you store model config and weights" +
|
+ "configuration object has to be provided. Usually the only practical way to do this is to store"
|
||||||
" separately no training configuration is attached.");
|
+ " your keras model with `model.save('model_path.h5'. If you store model config and weights"
|
||||||
|
+ " separately no training configuration is attached.");
|
||||||
}
|
}
|
||||||
|
|
||||||
this.outputTypes = inferOutputTypes(inputShape);
|
this.outputTypes = inferOutputTypes(inputShape);
|
||||||
|
@ -150,9 +179,7 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** Default constructor */
|
||||||
* Default constructor
|
|
||||||
*/
|
|
||||||
public KerasSequentialModel() {
|
public KerasSequentialModel() {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
@ -174,14 +201,14 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
"MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
|
||||||
|
NeuralNetConfiguration.builder();
|
||||||
|
|
||||||
if (optimizer != null) {
|
if (optimizer != null) {
|
||||||
modelBuilder.updater(optimizer);
|
modelBuilder.updater(optimizer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// don't forcibly override for keras import
|
||||||
//don't forcibly override for keras import
|
|
||||||
modelBuilder.overrideNinUponBuild(false);
|
modelBuilder.overrideNinUponBuild(false);
|
||||||
/* Add layers one at a time. */
|
/* Add layers one at a time. */
|
||||||
KerasLayer prevLayer = null;
|
KerasLayer prevLayer = null;
|
||||||
|
@ -192,7 +219,10 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
if (nbInbound != 1)
|
if (nbInbound != 1)
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
"Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
|
||||||
+ nbInbound + " for layer " + layer.getName() + ")");
|
+ nbInbound
|
||||||
|
+ " for layer "
|
||||||
|
+ layer.getName()
|
||||||
|
+ ")");
|
||||||
if (prevLayer != null) {
|
if (prevLayer != null) {
|
||||||
InputType[] inputTypes = new InputType[1];
|
InputType[] inputTypes = new InputType[1];
|
||||||
InputPreProcessor preprocessor;
|
InputPreProcessor preprocessor;
|
||||||
|
@ -200,42 +230,44 @@ public class KerasSequentialModel extends KerasModel {
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
|
||||||
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
preprocessor = prevLayer.getInputPreprocessor(inputTypes);
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
} else {
|
} else {
|
||||||
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
inputTypes[0] = this.outputTypes.get(prevLayer.getName());
|
||||||
preprocessor = layer.getInputPreprocessor(inputTypes);
|
preprocessor = layer.getInputPreprocessor(inputTypes);
|
||||||
if(preprocessor != null) {
|
if (preprocessor != null) {
|
||||||
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
InputType outputType = preprocessor.getOutputType(inputTypes[0]);
|
||||||
layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
|
layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
|
||||||
|
} else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
|
||||||
}
|
}
|
||||||
else
|
if (preprocessor != null) {
|
||||||
layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
|
|
||||||
|
|
||||||
|
Map<Integer, InputPreProcessor> map = new HashMap<>();
|
||||||
|
map.put(layerIndex, preprocessor);
|
||||||
|
modelBuilder.inputPreProcessors(map);
|
||||||
}
|
}
|
||||||
if (preprocessor != null)
|
|
||||||
modelBuilder.inputPreProcessor(layerIndex, preprocessor);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelBuilder.layer(layerIndex++, layer.getLayer());
|
modelBuilder.layer(layerIndex++, layer.getLayer());
|
||||||
} else if (layer.getVertex() != null)
|
} else if (layer.getVertex() != null)
|
||||||
throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
|
throw new InvalidKerasConfigurationException(
|
||||||
+ layer.getClassName() + ", layer name " + layer.getName() + ")");
|
"Cannot add vertex to NeuralNetConfiguration (class name "
|
||||||
|
+ layer.getClassName()
|
||||||
|
+ ", layer name "
|
||||||
|
+ layer.getName()
|
||||||
|
+ ")");
|
||||||
prevLayer = layer;
|
prevLayer = layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
/* Whether to use standard backprop (or BPTT) or truncated BPTT. */
|
||||||
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
|
||||||
modelBuilder.backpropType(BackpropType.TruncatedBPTT)
|
modelBuilder
|
||||||
|
.backpropType(BackpropType.TruncatedBPTT)
|
||||||
.tbpttFwdLength(truncatedBPTT)
|
.tbpttFwdLength(truncatedBPTT)
|
||||||
.tbpttBackLength(truncatedBPTT);
|
.tbpttBackLength(truncatedBPTT);
|
||||||
else
|
else modelBuilder.backpropType(BackpropType.Standard);
|
||||||
modelBuilder.backpropType(BackpropType.Standard);
|
|
||||||
|
|
||||||
NeuralNetConfiguration build = modelBuilder.build();
|
NeuralNetConfiguration build = modelBuilder.build();
|
||||||
|
|
||||||
|
|
||||||
return build;
|
return build;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -84,29 +84,29 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
ConvolutionLayer.ConvolutionLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
||||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion))
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
||||||
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
|
.stride(getStrideFromConfig(layerConfig, 1, conf));
|
||||||
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
builder.biasInit(0.0);
|
builder.biasInit(0.0);
|
||||||
if (padding != null)
|
if (padding != null)
|
||||||
builder.padding(padding[0]);
|
builder.padding(padding);
|
||||||
if (biasConstraint != null)
|
if (biasConstraint != null)
|
||||||
builder.constrainBias(biasConstraint);
|
builder.constrainBias(biasConstraint);
|
||||||
if (weightConstraint != null)
|
if (weightConstraint != null)
|
||||||
builder.constrainWeights(weightConstraint);
|
builder.constrainWeights(weightConstraint);
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer;
|
Convolution1D convolution1D = (Convolution1D) layer;
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -114,8 +114,8 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getAtrousConvolution1D() {
|
public Convolution1D getAtrousConvolution1D() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -24,6 +24,7 @@ import lombok.val;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -85,7 +86,7 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
val builder = ConvolutionLayer.builder().name(this.name)
|
val builder = Convolution2D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -93,7 +93,7 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
Convolution1DLayer.Convolution1DLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
@ -125,9 +125,9 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
//set this in order to infer the dimensional format
|
//set this in order to infer the dimensional format
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) this.layer;
|
Convolution1D convolution1D = (Convolution1D) this.layer;
|
||||||
convolution1DLayer.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
convolution1D.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,8 +135,8 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getConvolution1DLayer() {
|
public Convolution1D getConvolution1DLayer() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Convolution2D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -95,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
|
||||||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||||
|
|
||||||
final var builder = ConvolutionLayer.builder().name(this.name)
|
final var builder = Convolution2D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||||
|
@ -41,8 +42,8 @@ public class JsonTest extends BaseDL4JTest {
|
||||||
|
|
||||||
};
|
};
|
||||||
for(InputPreProcessor p : pp ){
|
for(InputPreProcessor p : pp ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(p);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(p);
|
||||||
InputPreProcessor p2 = NeuralNetConfiguration.mapper().readValue(s, InputPreProcessor.class);
|
InputPreProcessor p2 = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, InputPreProcessor.class);
|
||||||
assertEquals(p, p2);
|
assertEquals(p, p2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,8 @@ import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||||
|
@ -656,7 +653,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
||||||
true, true, false, null, null);
|
true, true, false, null, null);
|
||||||
Layer l = net.getLayer(0);
|
Layer l = net.getLayer(0);
|
||||||
Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig();
|
Convolution1D c1d = (Convolution1D) l.getTrainingConfig();
|
||||||
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -97,7 +97,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
Convolution1D layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -119,7 +119,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
Convolution1D layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,8 +22,6 @@
|
||||||
package net.brutex.ai.dnn.api;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
|
|
||||||
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
||||||
|
|
||||||
|
|
|
@ -31,9 +31,11 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
public class NN {
|
public class NN {
|
||||||
|
|
||||||
|
|
||||||
public static NeuralNetConfigurationBuilder<?, ?> net() {
|
public static NeuralNetConfigurationBuilder<?, ?> nn() {
|
||||||
return NeuralNetConfiguration.builder();
|
return NeuralNetConfiguration.builder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ package net.brutex.ai.dnn.networks;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
||||||
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
||||||
|
|
|
@ -20,6 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
@ -30,11 +34,6 @@ import org.deeplearning4j.earlystopping.termination.IterationTerminationConditio
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.nd4j.common.function.Supplier;
|
import org.nd4j.common.function.Supplier;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,16 +20,15 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
|
|
||||||
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
||||||
|
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
||||||
|
|
||||||
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
||||||
|
|
||||||
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -29,7 +30,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
@ -27,8 +29,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,12 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
||||||
|
|
|
@ -22,9 +22,7 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
/**Terminate training based on max time.
|
/**Terminate training based on max time.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
||||||
|
@ -40,13 +46,6 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
|
|
|
@ -20,6 +20,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.MapperFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.module.SimpleModule;
|
||||||
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.common.primitives.AtomicBoolean;
|
import org.nd4j.common.primitives.AtomicBoolean;
|
||||||
|
@ -28,14 +35,6 @@ import org.nd4j.common.primitives.serde.JsonDeserializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
||||||
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.MapperFeature;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.module.SimpleModule;
|
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,15 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import com.google.common.collect.HashMultiset;
|
|
||||||
import com.google.common.collect.Multiset;
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
||||||
|
|
|
@ -20,14 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.evaluation.EvaluationAveraging;
|
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Getter
|
@Getter
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.evaluation.curves.BaseHistogram;
|
import org.nd4j.evaluation.curves.BaseHistogram;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,13 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.NonNull;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.meta;
|
package org.deeplearning4j.eval.meta;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.adapters;
|
package org.deeplearning4j.nn.adapters;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
@ -32,8 +33,6 @@ import org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
public interface Classifier extends IModel {
|
public interface Classifier extends IModel {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface ITraininableLayerConfiguration {
|
public interface ITraininableLayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -29,10 +29,8 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Param initializer for a layer
|
* Param initializer for a layer
|
||||||
*
|
*
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the model
|
* Update the model
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Classifier;
|
import org.deeplearning4j.nn.api.Classifier;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface IOutputLayer extends Layer, Classifier {
|
public interface IOutputLayer extends Layer, Classifier {
|
||||||
|
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface LayerConstraint extends Cloneable, Serializable {
|
public interface LayerConstraint extends Cloneable, Serializable {
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import java.util.Map;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface RecurrentLayer extends Layer {
|
public interface RecurrentLayer extends Layer {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.*;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
|
@ -34,6 +40,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -42,16 +49,9 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@ -110,7 +110,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return YAML representation of configuration
|
* @return YAML representation of configuration
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
|
@ -127,7 +127,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return {@link ComputationGraphConfiguration}
|
* @return {@link ComputationGraphConfiguration}
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromYaml(String json) {
|
public static ComputationGraphConfiguration fromYaml(String json) {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
try {
|
try {
|
||||||
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
|
@ -140,7 +140,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
//As per NeuralNetConfiguration.toJson()
|
//As per NeuralNetConfiguration.toJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
||||||
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
||||||
|
@ -160,7 +160,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromJson(String json) {
|
public static ComputationGraphConfiguration fromJson(String json) {
|
||||||
//As per NeuralNetConfiguration.fromJson()
|
//As per NeuralNetConfiguration.fromJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
ComputationGraphConfiguration conf;
|
ComputationGraphConfiguration conf;
|
||||||
try {
|
try {
|
||||||
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
|
|
|
@ -19,10 +19,10 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
||||||
|
|
||||||
@JsonSerialize(using = DataFormatSerializer.class)
|
@JsonSerialize(using = DataFormatSerializer.class)
|
||||||
@JsonDeserialize(using = DataFormatDeserializer.class)
|
@JsonDeserialize(using = DataFormatDeserializer.class)
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue