# Compare commits

Comparing `master` ... `enhance-bu` (20 commits)
| Commit (SHA1) |
|---|
| `dd151aec3f` |
| `1b3338f809` |
| `3d949c5348` |
| `6930116c18` |
| `e27fb8422f` |
| `d0342fc939` |
| `b34b96d929` |
| `8f51471a31` |
| `dc5de40620` |
| `e834407b6e` |
| `4dc5a116b6` |
| `997143b9dd` |
| `0bed17c97f` |
| `8d73a7a410` |
| `c758cf918f` |
| `2c8c6d9624` |
| `0ba049885f` |
| `345f55a003` |
| `1c39dbee52` |
| `ea504bff41` |
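The compare touches four areas: the CUDA build image (a patch-level base-image bump), `.gitignore` (Gradle build output), several Jenkins pipeline definitions (a credential ID rename, a commented-out Docker build agent, and a new CUDA test stage), and the test sources (a new `ExploreParamsTest.java`, a substantially rewritten `net.brutex.gan.App`, and a new `App2.java`).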
**Dockerfile**

```diff
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
+FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
 
 RUN apt-get update &&  \
     DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
```
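The only change is the patch-level bump of the NVIDIA CUDA base image from 11.4.0 to 11.4.3; the JDK 11 and native build tooling installed on top stay the same. This is presumably the `Dockerfile` under `.docker/` that the Jenkins pipelines below reference.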
**.gitignore** (13 changes, vendored)
```diff
@@ -36,6 +36,8 @@ pom.xml.versionsBackup
 pom.xml.next
 release.properties
 *dependency-reduced-pom.xml
+**/build/*
+.gradle/*
 
 # Specific for Nd4j
 *.md5
@@ -83,3 +85,14 @@ bruai4j-native-common/cmake*
 /bruai4j-native/bruai4j-native-common/blasbuild/
 /bruai4j-native/bruai4j-native-common/build/
 /cavis-native/cavis-native-lib/blasbuild/
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
+/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
+/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
+/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
+/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
+/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
+/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
```
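The two new wildcard patterns `**/build/*` and `.gradle/*` appear to already cover the explicit `cavis-dnn` build-artifact paths added in the second hunk (git ignores everything below a directory matched by `**/build/*`), so the per-file entries look redundant but harmless.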
**Jenkins pipelines**

```diff
@@ -35,7 +35,7 @@ pipeline {
                 }
         stage('build-linux-cpu') {
             environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
             }
 
```

```diff
@@ -21,13 +21,15 @@
 
 pipeline {
     agent {
-        dockerfile {
+      /*  dockerfile {
             filename 'Dockerfile'
             dir '.docker'
             label 'linux && cuda'
             //additionalBuildArgs  '--build-arg version=1.0.2'
             //args '--gpus all' --needed for test only, you can build without GPU
         }
+        */
+       label 'linux && cuda'
     }
 
     stages {
```

```diff
@@ -43,13 +45,13 @@ pipeline {
                 }
         stage('build-linux-cuda') {
             environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
             }
 
             steps {
                 withGradle {
-                    sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
+                    sh 'sh ./gradlew build --stacktrace  -PCAVIS_CHIP=cuda \
                                 -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
                                 -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
                 }
```

```diff
@@ -47,7 +47,7 @@ pipeline {
                 }
                 stage('build-linux-cuda') {
                     environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
                     }
 
```

```diff
@@ -41,7 +41,7 @@ pipeline {
                 }
         stage('build-linux-cpu') {
             environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
             }
 
```

```diff
@@ -33,7 +33,7 @@ pipeline {
     stages {
         stage('publish-linux-cpu') {
             environment {
-                MAVEN = credentials('Internal Archiva')
+                MAVEN = credentials('Internal_Archiva')
                 OSSRH = credentials('OSSRH')
             }
 
```

```diff
@@ -43,7 +43,7 @@ pipeline {
                 }
         stage('build-linux-cuda') {
             environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
             }
 
```

```diff
@@ -56,5 +56,20 @@ pipeline {
                 //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
             }
         }
+        stage('test-linux-cuda') {
+            environment {
+                MAVEN = credentials('Internal_Archiva')
+                OSSRH = credentials('OSSRH')
+            }
+
+            steps {
+                withGradle {
+                    sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
+                                -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
+                                -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
+                }
+                //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
+            }
+        }
     }
 }
```

```diff
@@ -41,7 +41,7 @@ pipeline {
                 }
         stage('build-linux-cpu') {
             environment {
-                        MAVEN = credentials('Internal Archiva')
+                        MAVEN = credentials('Internal_Archiva')
                         OSSRH = credentials('OSSRH')
             }
 
```
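The recurring change across the pipeline files is the Archiva credential ID: `credentials('Internal Archiva')` becomes `credentials('Internal_Archiva')`, matching an ID without spaces on the Jenkins side; the environment binding still exposes the same `$MAVEN_USR`/`$MAVEN_PSW` pair that the Gradle invocations consume. The CUDA pipeline additionally stops skipping tests during the build (the `-x test` flag is dropped) and gains a dedicated `test-linux-cuda` stage that runs `./gradlew test` with `-PexcludeTests='long-running,performance'` and `-Pskip-native=true`.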
**ExploreParamsTest.java** (new file, 167 lines, `package net.brutex.ai.nd4j.tests`)

```java
/*
 *
 *    ******************************************************************************
 *    *
 *    * This program and the accompanying materials are made available under the
 *    * terms of the Apache License, Version 2.0 which is available at
 *    * https://www.apache.org/licenses/LICENSE-2.0.
 *    *
 *    *  See the NOTICE file distributed with this work for additional
 *    *  information regarding copyright ownership.
 *    * Unless required by applicable law or agreed to in writing, software
 *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *    * License for the specific language governing permissions and limitations
 *    * under the License.
 *    *
 *    * SPDX-License-Identifier: Apache-2.0
 *    *****************************************************************************
 *
 */

package net.brutex.ai.nd4j.tests;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

@Slf4j
public class ExploreParamsTest {

  @Test
  public void testParam() {
    NeuralNetConfiguration conf =
        NeuralNetConfiguration.builder()
            .seed(12345)
                .dataType(DataType.DOUBLE)
            .layer(
                DenseLayer.builder().nIn(4).nOut(30).name("1. Dense").activation(Activation.TANH))
           .layer(DenseLayer.builder().nIn(30).nOut(10).name("2. Dense"))
            //  .layer(FrozenLayer.builder(DenseLayer.builder().nOut(6).build()).build())

            .layer(
                OutputLayer.builder()
                    .nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MSE)
                    .activation(Activation.SOFTMAX))
            .build();
    MultiLayerNetwork nn = new MultiLayerNetwork(conf);
    nn.init();
    log.info(nn.summary());
    // INDArray input = Nd4j.rand(10,4);
    INDArray labels = Nd4j.zeros(9, 3);

    INDArray input =
        Nd4j.create(
            new double[][] {
              {5.15, 3.5, 1.4, 0.21},     // setosa
              {4.9, 3.2, 1.4, 0.2},       // setosa
              {4.7, 3.2, 1.23, 0.2},      // setosa
              {7, 3.25, 4.7, 1.41},       // versicolor
              {6.4, 3.2, 4.54, 1.5},      // versicolor
              {6.9, 3.1, 4.92, 1.5},      // versicolor
              {7.7, 3, 6.1, 2.3},         // virginica
              {6.3, 3.4, 5.6, 2.45},      // virginica
              {6.4, 3.12, 5.5, 1.8}       // virginica
            });

    labels.putScalar(0, 1);
    labels.putScalar(3, 1);
    labels.putScalar(6, 1);
    labels.putScalar(10, 1);
    labels.putScalar(13, 1);
    labels.putScalar(16, 1);
    labels.putScalar(20, 1);
    labels.putScalar(23, 1);
    labels.putScalar(26, 1);

    IrisDataSetIterator iter = new IrisDataSetIterator();
    //Iterable<Pair<INDArray, INDArray>> it = List.of(new Pair<INDArray, INDArray>(input, labels));
    List l = new ArrayList<>();
    for (int i=0; i< input.rows(); i++) {
      l.add(new Pair(input.getRow(i), labels.getRow(i)));
    }
    Iterable<Pair<INDArray, INDArray>> it = l;
    INDArrayDataSetIterator diter = new INDArrayDataSetIterator(it, 1);

    for (int i = 0; i < 100; i++) {
      // nn.fit(input, labels);
      // nn.fit( input, labels);
      nn.fit(diter);
      // nn.feedForward(input);
      if(i%20==0) log.info("Score: {}", nn.getScore());
    }

    Evaluation eval = nn.evaluate(iter, List.of("setosa", "vericolor", "virginica"));

    log.info("\n{}", eval.stats());
  }

  @Test
  public void testParam2() throws IOException {
    NeuralNetConfiguration conf =
            NeuralNetConfiguration.builder()
                    .seed(12345)
                    .layer(
                            DenseLayer.builder().nIn(784).nOut(20).name("1. Dense"))
                    .layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
                    .layer(
                            OutputLayer.builder()
                                    .nOut(10)
                                    .lossFunction(LossFunctions.LossFunction.MSE)
                                    .activation(Activation.SOFTMAX))
                    .build();
    MultiLayerNetwork nn = new MultiLayerNetwork(conf);
    nn.init();
    log.info(nn.summary());

    NeuralNetConfiguration conf2 =
            NeuralNetConfiguration.builder()
                    .seed(12345)
                    .layer(
                            DenseLayer.builder().nIn(784).nOut(20).name("1. Dense").dropOut(0.7))
                    .layer(DenseLayer.builder().nIn(20).nOut(10).name("2. Dense"))
                    .layer(
                            OutputLayer.builder()
                                    .nOut(10)
                                    .lossFunction(LossFunctions.LossFunction.MSE)
                                    .activation(Activation.SOFTMAX))
                    .build();
    MultiLayerNetwork nn2 = new MultiLayerNetwork(conf2);
    nn2.init();
    log.info(nn2.summary());

    MnistDataSetIterator iter = new MnistDataSetIterator(10, 500);
    MnistDataSetIterator iter2 = new MnistDataSetIterator(10, 50);

    for (int i = 0; i < 200; i++) {
      nn.fit(iter);
      nn2.fit(iter);
      if(i%20==0) log.info("Score: {} vs. {}", nn.getScore(), nn2.getScore());
    }

    Evaluation eval = nn.evaluate(iter2);
    Evaluation eval2 = nn2.evaluate(iter2);

    log.info("\n{} \n{}", eval.stats(), eval2.stats());
  }
}
```
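One detail in `testParam()` worth spelling out: the Iris labels are one-hot encoded by writing into a 9x3 zero matrix through *linear* indices, relying on Nd4j's default row-major ('c') ordering, where `putScalar(long, double)` treats the index as an offset into the flattened array. A minimal sketch of the same trick (the class and loop here are illustrative, not part of the commit):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OneHotSketch {
  public static void main(String[] args) {
    // 9 samples, 3 classes; rows 0-2 are setosa, 3-5 versicolor, 6-8 virginica
    INDArray labels = Nd4j.zeros(9, 3);
    for (int row = 0; row < 9; row++) {
      int clazz = row / 3;                  // class index for this row
      labels.putScalar(row * 3 + clazz, 1); // linear index = row * numColumns + column
    }
    // Produces the same matrix the test builds with the literal
    // indices 0, 3, 6, 10, 13, 16, 20, 23 and 26.
    System.out.println(labels);
  }
}
```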
**App.java** (`package net.brutex.gan`)

```diff
@@ -1,110 +1,48 @@
-/*
- *
- *    ******************************************************************************
- *    *
- *    * This program and the accompanying materials are made available under the
- *    * terms of the Apache License, Version 2.0 which is available at
- *    * https://www.apache.org/licenses/LICENSE-2.0.
- *    *
- *    *  See the NOTICE file distributed with this work for additional
- *    *  information regarding copyright ownership.
- *    * Unless required by applicable law or agreed to in writing, software
- *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- *    * License for the specific language governing permissions and limitations
- *    * under the License.
- *    *
- *    * SPDX-License-Identifier: Apache-2.0
- *    *****************************************************************************
- *
- */
-
 package net.brutex.gan;
 
-import java.awt.BorderLayout;
-import java.awt.Dimension;
-import java.awt.GridLayout;
-import java.awt.Image;
+import static net.brutex.ai.dnn.api.NN.dense;
+
+import java.awt.*;
 import java.awt.image.BufferedImage;
 import java.io.File;
 import java.util.Arrays;
-import java.util.Random;
-import javax.swing.ImageIcon;
-import javax.swing.JFrame;
-import javax.swing.JLabel;
-import javax.swing.JPanel;
-import javax.swing.WindowConstants;
-import lombok.extern.slf4j.Slf4j;
+import javax.swing.*;
 import org.apache.commons.lang3.ArrayUtils;
-import org.datavec.api.split.FileSplit;
-import org.datavec.image.loader.NativeImageLoader;
-import org.datavec.image.recordreader.ImageRecordReader;
-import org.datavec.image.transform.ColorConversionTransform;
-import org.datavec.image.transform.ImageTransform;
-import org.datavec.image.transform.PipelineImageTransform;
-import org.datavec.image.transform.ResizeImageTransform;
-import org.datavec.image.transform.ShowImageTransform;
-import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.distribution.Distribution;
-import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.ActivationLayer;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.DropoutLayer;
-import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
-import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.IUpdater;
-import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
 
-@Slf4j
 public class App {
-  private static final double LEARNING_RATE = 0.000002;
+    private static final double LEARNING_RATE = 0.002;
     private static final double GRADIENT_THRESHOLD = 100.0;
-
-  private static final int X_DIM = 20 ;
-  private static final int Y_DIM = 20;
-  private static final int CHANNELS = 1;
-  private static final int batchSize = 10;
-  private static final int INPUT = 128;
-
-  private static final int OUTPUT_PER_PANEL = 4;
-
-  private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS;
     private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
+    private static final int BATCHSIZE = 128;
     private static JFrame frame;
-  private static  JFrame frame2;
     private static JPanel panel;
-  private static JPanel panel2;
 
     private static LayerConfiguration[] genLayers() {
         return new LayerConfiguration[] {
-        DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
-        ActivationLayer.builder(Activation.LEAKYRELU).build(),
-
-        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
+                dense().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(),
                 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
+                dense().nIn(256).nOut(512).build(),
                 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-
-        DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
+                dense().nIn(512).nOut(1024).build(),
+                ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+                dense().nIn(1024).nOut(784).activation(Activation.TANH).build()
         };
     }
 
```
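The rewritten generator swaps the earlier 20x20 single-channel image pipeline (sized by `X_DIM`/`Y_DIM`/`CHANNELS`) for the classic MNIST GAN multilayer perceptron: a 100-dimensional noise input widened through 256, 512 and 1024 LeakyReLU-activated dense layers to a 784-unit (28x28) TANH output. The statically imported `dense()` from `net.brutex.ai.dnn.api.NN` appears to be shorthand for a `DenseLayer` builder; that reading is an inference from the import, not confirmed elsewhere in the diff.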
```diff
@@ -119,65 +57,51 @@ public class App {
                 .updater(UPDATER)
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
-        //.weightInit(WeightInit.XAVIER)
                 .weightInit(WeightInit.XAVIER)
                 .activation(Activation.IDENTITY)
                 .layersFromArray(genLayers())
-        .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-       // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
+                .name("generator")
                 .build();
-    ((NeuralNetConfiguration) conf).init();
 
         return conf;
     }
 
     private static LayerConfiguration[] disLayers() {
         return new LayerConfiguration[]{
-        DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
+                dense().nIn(784).nOut(1024).build(),
                 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                 DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
+                dense().nIn(1024).nOut(512).build(),
                 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                 DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
+                dense().nIn(512).nOut(256).build(),
                 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
                 DropoutLayer.builder(1 - 0.5).build(),
-        DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
-        ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-        DropoutLayer.builder(1 - 0.5).build(),
-
-        OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
+                OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
         };
     }
 
     private static NeuralNetConfiguration discriminator() {
-    NeuralNetConfiguration conf =
-        NeuralNetConfiguration.builder()
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .seed(42)
                 .updater(UPDATER)
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                 .weightInit(WeightInit.XAVIER)
-            //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-                .weightNoise(null)
-            // .weightInitFn(new WeightInitXavier())
-            // .activationFn(new ActivationIdentity())
                 .activation(Activation.IDENTITY)
                 .layersFromArray(disLayers())
-            .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
+                .name("discriminator")
                 .build();
-    ((NeuralNetConfiguration) conf).init();
 
         return conf;
     }
 
     private static NeuralNetConfiguration gan() {
         LayerConfiguration[] genLayers = genLayers();
-    LayerConfiguration[] disLayers = Arrays.stream(disLayers())
+        LayerConfiguration[] disLayers = discriminator().getFlattenedLayerConfigurations().stream()
                 .map((layer) -> {
                     if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
-          return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
+                        return FrozenLayerWithBackprop.builder(layer).build();
                     } else {
                         return layer;
                     }
@@ -186,107 +110,57 @@ public class App {
 
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .seed(42)
-        .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
-        .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
-        .gradientNormalizationThreshold( 100 )
-        //.weightInitFn( new WeightInitXavier() ) //this is internal
-            .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
-        .weightInit( WeightInit.XAVIER)
-        //.activationFn( new ActivationIdentity()) //this is internal
-        .activation( Activation.IDENTITY )
-        .layersFromArray(  layers  )
-        .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
+                .updater(UPDATER)
+                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
+                .weightInit(WeightInit.XAVIER)
+                .activation(Activation.IDENTITY)
+                .layersFromArray(layers)
+                .name("GAN")
                 .build();
-((NeuralNetConfiguration) conf).init();
 
         return conf;
     }
 
-
     @Test
     public void runTest() throws Exception {
-    main();
+        App.main(null);
     }
 
-
     public static void main(String... args) throws Exception {
+        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 
-    log.info("\u001B[32m  Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m   ");
-    Nd4j.getMemoryManager().setAutoGcWindow(500);
-
-//    MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
-  //  FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
-    FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS());
-
-    ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
-
-    ImageTransform transform2 = new ShowImageTransform("Tester", 30);
-    ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM);
-
-    ImageTransform tr = new PipelineImageTransform.Builder()
-        .addImageTransform(transform) //convert to GREY SCALE
-        .addImageTransform(transform3)
-        //.addImageTransform(transform2)
-        .build();
-
-    ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS);
-    imageRecordReader.initialize(fileSplit, tr);
-    DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize );
+        MnistDataSetIterator trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
 
         MultiLayerNetwork gen = new MultiLayerNetwork(generator());
         MultiLayerNetwork dis = new MultiLayerNetwork(discriminator());
         MultiLayerNetwork gan = new MultiLayerNetwork(gan());
-    gen.init(); log.debug("Generator network: {}", gen);
-    dis.init(); log.debug("Discriminator network: {}", dis);
-    gan.init(); log.debug("Complete GAN network: {}", gan);
+        gen.init();
+        dis.init();
+        gan.init();
 
         copyParams(gen, dis, gan);
 
-    gen.addTrainingListeners(new PerformanceListener(15, true));
-    //dis.addTrainingListeners(new PerformanceListener(10, true));
-    //gan.addTrainingListeners(new PerformanceListener(10, true));
-    //gan.addTrainingListeners(new ScoreToChartListener("gan"));
-    //dis.setListeners(new ScoreToChartListener("dis"));
+        gen.addTrainingListeners(new PerformanceListener(10, true));
+        dis.addTrainingListeners(new PerformanceListener(10, true));
+        gan.addTrainingListeners(new PerformanceListener(10, true));
 
-    System.out.println(gan.toString());
-    gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
-
-    //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1)));
-    //trainData.reset();
+        trainData.reset();
 
         int j = 0;
-    for (int i = 0; i < 201; i++) { //epoch
+        for (int i = 0; i < 50; i++) {
             while (trainData.hasNext()) {
                 j++;
 
-        DataSet next = trainData.next();
                 // generate data
-        INDArray real = next.getFeatures();//.div(255f);
-
-        //start next round if there are not enough images left to have a full batchsize dataset
-        if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) {
-          log.warn("Your total number of input images is not a multiple of {}, "
-              + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE);
-        break;
-        }
-
-        if(i%20 == 0) {
-         // frame2 = visualize(new INDArray[]{real}, batchSize,
-         //     frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
-        }
-       real.divi(255f);
-
-//        int batchSize = (int) real.shape()[0];
-
-        INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM);
+                INDArray real = trainData.next().getFeatures().muli(2).subi(1);
+                int batchSize = (int) real.shape()[0];
+
+                INDArray fakeIn = Nd4j.rand(batchSize, 100);
                 INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
-        fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM);
 
-        //log.info("real has {} items.", real.length());
                 DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
                 DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));
 
                 DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
 
                 dis.fit(data);
@@ -295,32 +169,26 @@ public class App {
                 // Update the discriminator in the GAN network
                 updateGan(gen, dis, gan);
 
-        //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1)));
-        gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)));
+                gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
 
                 if (j % 10 == 1) {
-          System.out.println("Iteration " + j + " Visualizing...");
-          INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize];
-
-          for (int k = 0; k < samples.length; k++) {
-            //INDArray input = fakeSet2.get(k).getFeatures();
+                    System.out.println("Epoch " + i +" Iteration " + j + " Visualizing...");
+                    INDArray[] samples = new INDArray[9];
                     DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1));
-            INDArray input = fakeSet2.get(k).getFeatures();
-            input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here
 
+                    for (int k = 0; k < 9; k++) {
+                        INDArray input = fakeSet2.get(k).getFeatures();
                         //samples[k] = gen.output(input, false);
                         samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input);
-            samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM);
-            //samples[k] =
-            samples[k].addi(1f).divi(2f).muli(255f);
 
                     }
-          frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
+                    visualize(samples);
                 }
             }
             trainData.reset();
+            // Copy the GANs generator to gen.
+            //updateGen(gen, gan);
         }
 
         // Copy the GANs generator to gen.
@@ -333,10 +201,8 @@ public class App {
         int genLayerCount = gen.getLayers().length;
         for (int i = 0; i < gan.getLayers().length; i++) {
             if (i < genLayerCount) {
-        if(gan.getLayer(i).getParams() != null)
                 gen.getLayer(i).setParams(gan.getLayer(i).getParams());
             } else {
-        if(gan.getLayer(i).getParams() != null)
                 dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
             }
         }
@@ -355,57 +221,41 @@ public class App {
         }
     }
 
-  private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
-    if (isOrig) {
-      frame.setTitle("Viz Original");
-    } else {
-      frame.setTitle("Generated");
-    }
-
+    private static void visualize(INDArray[] samples) {
+        if (frame == null) {
+            frame = new JFrame();
+            frame.setTitle("Viz");
             frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
             frame.setLayout(new BorderLayout());
 
-    JPanel panelx = new JPanel();
-
-    panelx.setLayout(new GridLayout(4, 4, 8, 8));
-    for (INDArray sample : samples) {
-      for(int i = 0; i<batchElements; i++) {
-        panelx.add(getImage(sample, i, isOrig));
-      }
-    }
-    frame.add(panelx, BorderLayout.CENTER);
+            panel = new JPanel();
+
+            panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8));
+            frame.add(panel, BorderLayout.CENTER);
             frame.setVisible(true);
+        }
 
+        panel.removeAll();
+
+        for (INDArray sample : samples) {
+            panel.add(getImage(sample));
+        }
 
         frame.revalidate();
-    frame.setMinimumSize(new Dimension(300, 20));
         frame.pack();
-    return frame;
     }
 
-  private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
-    final BufferedImage bi = new BufferedImage(X_DIM, Y_DIM, BufferedImage.TYPE_BYTE_GRAY);
-    final int imageSize = X_DIM * Y_DIM;
-    final int offset = batchElement * imageSize;
-    int pxl = offset * CHANNELS; //where to start in the INDArray
-
-    //Image in NCHW - channels first format
-    for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
-      for (int y = 0; y < Y_DIM; y++) { // step through the columns x
-        for (int x = 0; x < X_DIM; x++) { //step through the rows y
-          if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, tensor.getFloat(pxl));
-          bi.getRaster().setSample(x, y, c, tensor.getFloat(pxl));
-          pxl++; //next item in INDArray
+    private static JLabel getImage(INDArray tensor) {
+        BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+        for (int i = 0; i < 784; i++) {
+            int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255);
+            bi.getRaster().setSample(i % 28, i / 28, 0, pixel);
         }
-      }
-    }
-
         ImageIcon orig = new ImageIcon(bi);
-
-    Image imageScaled = orig.getImage().getScaledInstance((4 * X_DIM), (4 * Y_DIM), Image.SCALE_DEFAULT);
+        Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE);
 
         ImageIcon scaled = new ImageIcon(imageScaled);
 
         return new JLabel(scaled);
-
     }
 }
```
 | 
				
			||||||
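
Editor's note on the hunk above: dropping the two `if(gan.getLayer(i).getParams() != null)` guards means setParams() is now called unconditionally for every layer. A minimal sketch of the copy logic this hunk leaves in place, built only from the MultiLayerNetwork getters already used in the diff (the helper name itself is hypothetical):

    // Hypothetical helper summarizing the remaining logic: the composed GAN
    // stacks the generator's layers first, then the discriminator's, so index i
    // maps to gen for i < genLayerCount and to dis at offset i - genLayerCount.
    private static void copyParamsFromGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            if (i < genLayerCount) {
                gen.getLayer(i).setParams(gan.getLayer(i).getParams());
            } else {
                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
            }
        }
    }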
brutex-extended-tests/src/test/java/net/brutex/gan/App2.java (new file, 343 lines)
@@ -0,0 +1,343 @@
+/*
+ *
+ *    ******************************************************************************
+ *    *
+ *    * This program and the accompanying materials are made available under the
+ *    * terms of the Apache License, Version 2.0 which is available at
+ *    * https://www.apache.org/licenses/LICENSE-2.0.
+ *    *
+ *    *  See the NOTICE file distributed with this work for additional
+ *    *  information regarding copyright ownership.
+ *    * Unless required by applicable law or agreed to in writing, software
+ *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *    * License for the specific language governing permissions and limitations
+ *    * under the License.
+ *    *
+ *    * SPDX-License-Identifier: Apache-2.0
+ *    *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+import java.util.List;
+import javax.imageio.ImageIO;
+import javax.swing.*;
+import lombok.extern.slf4j.Slf4j;
+import org.datavec.api.split.FileSplit;
+import org.datavec.image.loader.NativeImageLoader;
+import org.datavec.image.recordreader.ImageRecordReader;
+import org.datavec.image.transform.*;
+import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
+import org.junit.jupiter.api.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+
+@Slf4j
+public class App2 {
+
+    final int INPUT = CHANNELS*DIMENSIONS*DIMENSIONS;
+    static final float COLORSPACE = 255f;
+    static final int DIMENSIONS = 28;
+    static final int CHANNELS = 1;
+    final int ARRAY_SIZE_PER_SAMPLE = DIMENSIONS*DIMENSIONS*CHANNELS;
+    final int OUTPUT_PER_PANEL = 10;
+
+    final boolean BIAS = true;
+
+    static final int BATCHSIZE=128;
+
+    private JFrame frame2, frame;
+    static final String OUTPUT_DIR = "d:/out/";
+
+    final static INDArray label_real = Nd4j.ones(BATCHSIZE, 1);
+    final static INDArray label_fake = Nd4j.zeros(BATCHSIZE, 1);
+
+    @Test
+    void runTest() throws IOException {
+        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+
+        MnistDataSetIterator mnistIter = new MnistDataSetIterator(20, 200);
+        FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans2"), NativeImageLoader.getALLOWED_FORMATS());
+        ImageTransform transform = new ColorConversionTransform(new Random(42), 7 );
+        ImageTransform transform2 = new ShowImageTransform("Tester", 30);
+        ImageTransform transform3 = new ResizeImageTransform(DIMENSIONS, DIMENSIONS);
+
+        ImageTransform tr = new PipelineImageTransform.Builder()
+               .addImageTransform(transform) //convert to GREY SCALE
+                .addImageTransform(transform3)
+                //.addImageTransform(transform2)
+                .build();
+
+        ImageRecordReader imageRecordReader = new ImageRecordReader(DIMENSIONS, DIMENSIONS, CHANNELS);
+        imageRecordReader.initialize(fileSplit, tr);
+        DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, BATCHSIZE );
+        trainData = new MnistDataSetIterator(BATCHSIZE, true, 42);
+
+        MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator());
+        MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());
+
+        LayerConfiguration[] disLayers = App2Config.discriminator().getFlattenedLayerConfigurations().stream()
+                .map((layer) -> {
+                    if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
+                        return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
+                    } else {
+                        return layer;
+                    }
+                }).toArray(LayerConfiguration[]::new);
+
+    NeuralNetConfiguration netConfiguration =
+        NeuralNetConfiguration.builder()
+            .name("GAN")
+            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                .gradientNormalizationThreshold(100)
+                .updater(App2Config.UPDATER)
+            .innerConfigurations(new ArrayList<>(List.of(App2Config.generator())))
+            .layersFromList(new ArrayList<>(Arrays.asList(disLayers)))
+            // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
+            // .inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
+            //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
+           // .inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
+                //.inputPreProcessor(2, new CnnToFeedForwardPreProcessor())
+                //.dataType(DataType.FLOAT)
+            .build();
+
+        MultiLayerNetwork gan = new MultiLayerNetwork(netConfiguration );
+
+        dis.init(); log.debug("Discriminator network: {}", dis);
+        gen.init(); log.debug("Generator network: {}", gen);
+        gan.init(); log.debug("GAN network: {}", gan);
+
+        log.info("Generator Summary:\n{}", gen.summary());
+        log.info("GAN Summary:\n{}", gan.summary());
+        dis.addTrainingListeners(new PerformanceListener(10, true, "DIS"));
+        gen.addTrainingListeners(new PerformanceListener(10, true, "GEN"));
+        gan.addTrainingListeners(new PerformanceListener(10, true, "GAN"));
+
+        int j = 0;
+        for (int i = 0; i < 51; i++) { //epoch
+            while (trainData.hasNext()) {
+                j++;
+                DataSet next = trainData.next();
+                // generate data
+                INDArray real = next.getFeatures(); //.muli(2).subi(1);;//.div(255f);
+
+                //start next round if there are not enough images left to have a full batchsize dataset
+                if(real.length() < ARRAY_SIZE_PER_SAMPLE*BATCHSIZE) {
+                    log.warn("Your total number of input images is not a multiple of {}, "
+                            + "thus skipping {} images to make it fit", BATCHSIZE, real.length()/ARRAY_SIZE_PER_SAMPLE);
+                    break;
+                }
+
+                //if(i%20 == 0) {
+               // frame2 = visualize(new INDArray[]{real}, BATCHSIZE,
+               //         frame2 == null ? new JFrame() : frame2, true); //real has batchsize number of images
+                //}
+                //real.divi(255f);
+
+//        int batchSize = (int) real.shape()[0];
+
+                //INDArray fakeIn = Nd4j.rand(BATCHSIZE, CHANNELS, DIMENSIONS, DIMENSIONS);
+                //INDArray fakeIn = Nd4j.rand(new int[]{batchSize, X_DIM*Y_DIM}); //hack for MNIST only, use above otherwise
+                INDArray fakeIn = Nd4j.rand(BATCHSIZE, App2Config.INPUT);
+
+                INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn);
+                // when generator has TANH as activation - value range is -1 to 1
+                // when generator has SIGMOID, then range is 0 to 1
+                fake.addi(1f).divi(2f);
+
+                DataSet realSet = new DataSet(real, label_real);
+                DataSet fakeSet = new DataSet(fake, label_fake);
+
+                DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));
+
+                dis.fit(data);
+                dis.fit(data);
+                // Update the discriminator in the GAN network
+                updateGan(gen, dis, gan);
+
+                gan.fit(new DataSet(Nd4j.rand(BATCHSIZE, App2Config.INPUT), label_fake));
+
+                //Visualize and reporting
+                if (j % 10 == 1) {
+                    System.out.println("Epoch " + i + " Iteration " + j + " Visualizing...");
+                    INDArray[] samples = BATCHSIZE > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[BATCHSIZE];
+
+                    for (int k = 0; k < samples.length; k++) {
+                        DataSet fakeSet2 = new DataSet(fakeIn, label_fake);
+                        INDArray input = fakeSet2.get(k).getFeatures();
+
+                        //input = input.reshape(1,CHANNELS, DIMENSIONS, DIMENSIONS); //batch size will be 1 here for images
+                        input = input.reshape(1, App2Config.INPUT);
+
+                        //samples[k] = gen.output(input, false);
+                        samples[k] = gen.activateSelectedLayers(0, gen.getLayers().length - 1, input);
+                        samples[k] = samples[k].reshape(1, CHANNELS, DIMENSIONS, DIMENSIONS);
+                        //samples[k] =
+                        //samples[k].muli(255f);
+
+                    }
+                    frame = visualize(samples, 1, frame == null ? new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1
+                }
+            }
+
+            if (trainData.resetSupported()) {
+                trainData.reset();
+            } else {
+                log.error("Trainingdata {} does not support reset.", trainData.toString());
+            }
+            // Copy the GANs generator to gen.
+            updateGen(gen, gan);
+            log.info("Updated GAN's generator from gen.");
+            gen.save(new File("mnist-mlp-generator.dlj"));
+        }
+    }
+
+    private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) {
+        if (isOrig) {
+            frame.setTitle("Viz Original");
+        } else {
+            frame.setTitle("Generated");
+        }
+
+        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+        frame.setLayout(new BorderLayout());
+
+        JPanel panelx = new JPanel();
+
+        panelx.setLayout(new GridLayout(4, 4, 8, 8));
+        for (INDArray sample : samples) {
+            for(int i = 0; i<batchElements; i++) {
+                panelx.add(getImage(sample, i, isOrig));
+            }
+        }
+        frame.add(panelx, BorderLayout.CENTER);
+        frame.setVisible(true);
+
+        frame.revalidate();
+        frame.setMinimumSize(new Dimension(300, 20));
+        frame.pack();
+        return frame;
+    }
+
+    private static JLabel getImage(INDArray tensor, int batchElement, boolean isOrig) {
+        final BufferedImage bi;
+        if(CHANNELS >1) {
+            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_INT_RGB); //need to change here based on channels
+        } else {
+            bi = new BufferedImage(DIMENSIONS, DIMENSIONS, BufferedImage.TYPE_BYTE_GRAY); //need to change here based on channels
+        }
+        final int imageSize = DIMENSIONS * DIMENSIONS;
+        final int offset = batchElement * imageSize;
+        int pxl = offset * CHANNELS; //where to start in the INDArray
+
+        //Image in NCHW - channels first format
+        for (int c = 0; c < CHANNELS; c++) { //step through the num channels for each pixel
+            for (int y = 0; y < DIMENSIONS; y++) { // step through the columns x
+                for (int x = 0; x < DIMENSIONS; x++) { //step through the rows y
+                    float f_pxl = tensor.getFloat(pxl) * COLORSPACE;
+                    if(isOrig) log.trace("'{}.' Image (x,y,c): ({}, {}, {}) with INDArray with index {} and value '{}'", batchElement, x, y, c, pxl, f_pxl);
+                    bi.getRaster().setSample(x, y, c, f_pxl);
+                    pxl++; //next item in INDArray
+                }
+            }
+        }
+        ImageIcon orig = new ImageIcon(bi);
+        Image imageScaled = orig.getImage().getScaledInstance((4 * DIMENSIONS), (4 * DIMENSIONS), Image.SCALE_DEFAULT);
+        ImageIcon scaled = new ImageIcon(imageScaled);
+        if(! isOrig)  saveImage(imageScaled, batchElement, isOrig);
+        return new JLabel(scaled);
+
+    }
+
+    private static void saveImage(Image image, int batchElement, boolean isOrig) {
+        String outputDirectory = OUTPUT_DIR; // Set the output directory where the images will be saved
+
+        try {
+            // Save the images to disk
+            saveImage(image, outputDirectory, UUID.randomUUID().toString()+".png");
+
+            log.debug("Images saved successfully.");
+        } catch (IOException e) {
+            log.error("Error saving the images: {}", e.getMessage());
+        }
+    }
+
+    private static void saveImage(Image image, String outputDirectory, String fileName) throws IOException {
+        File directory = new File(outputDirectory);
+        if (!directory.exists()) {
+            directory.mkdir();
+        }
+
+        File outputFile = new File(directory, fileName);
+        ImageIO.write(imageToBufferedImage(image), "png", outputFile);
+    }
+
+    public static BufferedImage imageToBufferedImage(Image image) {
+        if (image instanceof BufferedImage) {
+            return (BufferedImage) image;
+        }
+
+        // Create a buffered image with the same dimensions and transparency as the original image
+        BufferedImage bufferedImage;
+        if (CHANNELS > 1) {
+            bufferedImage =
+                new BufferedImage(
+                    image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_INT_ARGB);
+        } else {
+            bufferedImage =
+                new BufferedImage(
+                    image.getWidth(null), image.getHeight(null), BufferedImage.TYPE_BYTE_GRAY);
+        }
+
+        // Draw the original image onto the buffered image
+        Graphics2D g2d = bufferedImage.createGraphics();
+        g2d.drawImage(image, 0, 0, null);
+        g2d.dispose();
+
+        return bufferedImage;
+    }
+
+    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
+        for (int i = 0; i < gen.getLayers().length; i++) {
+            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
+        }
+    }
+
+    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
+        int genLayerCount = gen.getLayers().length;
+        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
+            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
+        }
+    }
+
+}
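
Editor's note: the epoch loop above saves the generator to mnist-mlp-generator.dlj after each epoch. A minimal sketch of loading it back and sampling one image, assuming the standard DL4J MultiLayerNetwork.load API; the file name comes from the diff, everything else is illustrative:

    // Illustrative only: reload the saved generator and push one noise vector
    // through it. App2Config.INPUT (100) is the size of the noise input.
    MultiLayerNetwork gen = MultiLayerNetwork.load(new File("mnist-mlp-generator.dlj"), true);
    INDArray noise = Nd4j.rand(1, App2Config.INPUT);
    INDArray sample = gen.output(noise);            // shape [1, 784]
    INDArray image = sample.reshape(1, 1, 28, 28);  // NCHW, as in visualize()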
@@ -0,0 +1,176 @@
+/*
+ *
+ *    ******************************************************************************
+ *    *
+ *    * This program and the accompanying materials are made available under the
+ *    * terms of the Apache License, Version 2.0 which is available at
+ *    * https://www.apache.org/licenses/LICENSE-2.0.
+ *    *
+ *    *  See the NOTICE file distributed with this work for additional
+ *    *  information regarding copyright ownership.
+ *    * Unless required by applicable law or agreed to in writing, software
+ *    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *    * License for the specific language governing permissions and limitations
+ *    * under the License.
+ *    *
+ *    * SPDX-License-Identifier: Apache-2.0
+ *    *****************************************************************************
+ *
+ */
+
+package net.brutex.gan;
+
+import static net.brutex.ai.dnn.api.NN.*;
+
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+public class App2Config {
+
+  public static final int INPUT = 100;
+  public static final int X_DIM = 28;
+  public static final int y_DIM = 28;
+  public static final int CHANNELS = 1;
+  public static final IUpdater UPDATER = Adam.builder().learningRate(0.0002).beta1(0.5).build();
+
+  static LayerConfiguration[] genLayerConfig() {
+    return new LayerConfiguration[] {
+            /*
+      DenseLayer.builder().name("L-0").nIn(INPUT).nOut(INPUT + (INPUT / 2)).activation(Activation.RELU).build(),
+      ActivationLayer.builder().activation(Activation.RELU).build(), /*
+                Deconvolution2D.builder().name("L-Deconv-01").nIn(CHANNELS).nOut(CHANNELS)
+                        .kernelSize(2,2)
+                        .stride(1,1)
+                        .padding(0,0)
+                        .convolutionMode(ConvolutionMode.Truncate)
+                        .activation(Activation.RELU)
+                        .hasBias(BIAS).build(),
+                //BatchNormalization.builder().nOut(CHANNELS).build(),
+                Deconvolution2D.builder().name("L-Deconv-02").nIn(CHANNELS).nOut(CHANNELS)
+                        .kernelSize(2,2)
+                        .stride(2,2)
+                        .padding(0,0)
+                        .convolutionMode(ConvolutionMode.Truncate)
+                        .activation(Activation.RELU)
+                        .hasBias(BIAS).build(),
+               //BatchNormalization.builder().name("L-batch").nOut(CHANNELS).build(),
+
+      DenseLayer.builder().name("L-x").nIn(INPUT + (INPUT / 2)).nOut(2 * INPUT).build(),
+      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
+      DenseLayer.builder().name("L-x").nIn(2 * INPUT).nOut(3 * INPUT).build(),
+      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
+      DenseLayer.builder().name("L-x").nIn(3 * INPUT).nOut(2 * INPUT).build(),
+      ActivationLayer.builder().activation(Activation.RELU).dropOut(0.2).build(),
+      // DropoutLayer.builder(0.001).build(),
+      DenseLayer.builder().nIn(2 * INPUT).nOut(INPUT).activation(Activation.TANH).build()    */
+
+            dense().nIn(INPUT).nOut(256).weightInit(WeightInit.NORMAL).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            dense().nIn(256).nOut(512).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            dense().nIn(512).nOut(1024).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            dense().nIn(1024).nOut(784).activation(Activation.TANH).build(),
+
+    };
+  }
+
+  static LayerConfiguration[] disLayerConfig() {
+    return new LayerConfiguration[] {/*
+                Convolution2D.builder().nIn(CHANNELS).kernelSize(2,2).padding(1,1).stride(1,1).nOut(CHANNELS)
+                        .build(),
+                Convolution2D.builder().nIn(CHANNELS).kernelSize(3,3).padding(1,1).stride(2,2).nOut(CHANNELS)
+                        .build(),
+                ActivationLayer.builder().activation(Activation.LEAKYRELU).build(),
+                BatchNormalization.builder().build(),
+                OutputLayer.builder().nOut(1).lossFunction(LossFunctions.LossFunction.MCXENT)
+                        .activation(Activation.SIGMOID)
+                        .build()
+
+            dense().name("L-dense").nIn(INPUT).nOut(INPUT).build(),
+            ActivationLayer.builder().activation(Activation.RELU).build(),
+            DropoutLayer.builder(0.5).build(),
+
+            DenseLayer.builder().nIn(INPUT).nOut(INPUT/2).build(),
+            ActivationLayer.builder().activation(Activation.RELU).build(),
+            DropoutLayer.builder(0.5).build(),
+
+            DenseLayer.builder().nIn(INPUT/2).nOut(INPUT/4).build(),
+            ActivationLayer.builder().activation(Activation.RELU).build(),
+            DropoutLayer.builder(0.5).build(),
+
+            OutputLayer.builder().nIn(INPUT/4).nOut(1).lossFunction(LossFunctions.LossFunction.XENT)
+                    .activation(Activation.SIGMOID)
+                    .build() */
+            dense().nIn(784).nOut(1024).hasBias(true).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            DropoutLayer.builder(1 - 0.5).build(),
+            dense().nIn(1024).nOut(512).hasBias(true).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            DropoutLayer.builder(1 - 0.5).build(),
+            dense().nIn(512).nOut(256).hasBias(true).build(),
+            ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
+            DropoutLayer.builder(1 - 0.5).build(),
+            OutputLayer.builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build()
+    };
+  }
+
+  static NeuralNetConfiguration generator() {
+    NeuralNetConfiguration conf =
+            NeuralNetConfiguration.builder()
+                    .name("generator")
+                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                    .gradientNormalizationThreshold(100)
+                    .seed(42)
+                    .updater(UPDATER)
+                    .weightInit(WeightInit.XAVIER)
+                    //.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
+                    .weightNoise(null)
+                    // .weightInitFn(new WeightInitXavier())
+                    // .activationFn(new ActivationIdentity())
+                    .activation(Activation.IDENTITY)
+                    .layersFromArray(App2Config.genLayerConfig())
+                    // .inputType(InputType.convolutional(DIMENSIONS, DIMENSIONS, CHANNELS))
+                    //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
+                    //.inputPreProcessor(2, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
+                    //.inputPreProcessor(4, new CnnToFeedForwardPreProcessor())
+                    .build();
+    conf.init();
+    return conf;
+  }
+
+  static NeuralNetConfiguration discriminator() {
+    NeuralNetConfiguration conf =
+            NeuralNetConfiguration.builder()
+                    .name("discriminator")
+                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
+                    .gradientNormalizationThreshold(100)
+                    .seed(42)
+                    .updater(UPDATER)
+                    .weightInit(WeightInit.XAVIER)
+                    // .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
+                    .weightNoise(null)
+                    // .weightInitFn(new WeightInitXavier())
+                    // .activationFn(new ActivationIdentity())
+                    .activation(Activation.IDENTITY)
+                    .layersFromArray(disLayerConfig())
+                    //.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(DIMENSIONS, DIMENSIONS, CHANNELS))
+                    //.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())
+                    //.dataType(DataType.FLOAT)
+                    .build();
+    conf.init();
+    return conf;
+  }
+}
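
Editor's note: for orientation, the generator above maps noise of size INPUT (100) through 256, 512 and 1024 units to a 784-value (28x28) TANH image, and the discriminator mirrors this back down to a single SIGMOID real/fake score. A minimal sketch of wiring the two configurations up, using only the factory methods defined in this file:

    // Sketch (assumes App2Config above is on the classpath):
    MultiLayerNetwork gen = new MultiLayerNetwork(App2Config.generator());     // 100 -> 256 -> 512 -> 1024 -> 784
    MultiLayerNetwork dis = new MultiLayerNetwork(App2Config.discriminator()); // 784 -> 1024 -> 512 -> 256 -> 1
    gen.init();
    dis.init();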
@@ -24,12 +24,14 @@ package net.brutex.gan;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -98,7 +100,10 @@ public class MnistSimpleGAN {

     return new MultiLayerNetwork(discConf);
   }
+  @Test
+  public void runTest() throws Exception {
+    main(null);
+  }
   public static void main(String[] args) throws Exception {
     GAN gan = new GAN.Builder()
         .generator(MnistSimpleGAN::getGenerator)
@@ -108,6 +113,7 @@ public class MnistSimpleGAN {
         .updater(UPDATER)
         .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
         .gradientNormalizationThreshold(100)
+
         .build();

     Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
@@ -88,7 +88,7 @@ public class CNN1DTestCases {
                         .convolutionMode(ConvolutionMode.Same))
                         .graphBuilder()
                         .addInputs("in")
-                        .layer("0", Convolution1DLayer.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
+                        .layer("0", Convolution1D.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
                         .layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
                         .layer("2", Cropping1D.builder(1).build(), "1")
                         .layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
@@ -25,7 +25,7 @@
 # Default logging detail level for all instances of SimpleLogger.
 # Must be one of ("trace", "debug", "info", "warn", or "error").
 # If not specified, defaults to "info".
-org.slf4j.simpleLogger.defaultLogLevel=trace
+org.slf4j.simpleLogger.defaultLogLevel=debug

 # Logging detail level for a SimpleLogger instance named "xxxxx".
 # Must be one of ("trace", "debug", "info", "warn", or "error").
@@ -42,8 +42,8 @@ org.slf4j.simpleLogger.defaultLogLevel=trace
 # If the format is not specified or is invalid, the default format is used.
 # The default format is yyyy-MM-dd HH:mm:ss:SSS Z.
 #org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z
-org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss
+#org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss

 # Set to true if you want to output the current thread name.
 # Defaults to true.
-org.slf4j.simpleLogger.showThreadName=true
+#org.slf4j.simpleLogger.showThreadName=true
@@ -71,7 +71,7 @@ dependencies {
        // api "com.fasterxml.jackson.module:jackson-module-scala_${scalaVersion}"


-        api "org.projectlombok:lombok:1.18.26"
+        api "org.projectlombok:lombok:1.18.28"

         /*Logging*/
         api 'org.slf4j:slf4j-api:2.0.3'
@@ -2386,7 +2386,11 @@ public interface INDArray extends Serializable, AutoCloseable {
     long[] stride();

   /**
-   * Return the ordering (fortran or c  'f' and 'c' respectively) of this ndarray
+   * Return the ordering (fortran or c  'f' and 'c' respectively) of this ndarray <br/><br/>
+   * C Is Contiguous layout. Mathematically speaking, row major.<br/>
+   * F Is Fortran contiguous layout. Mathematically speaking, column major.<br/>
+   * {@see https://en.wikipedia.org/wiki/Row-_and_column-major_order}<br/>
+   *
    * @return the ordering of this ndarray
    */
  char ordering();
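
Editor's note: the expanded javadoc can be made concrete with a small example; a sketch assuming the usual Nd4j.create(data, shape, ordering) overload:

    // Same four values, two orderings: 'c' fills row by row, 'f' column by column.
    INDArray c = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2}, 'c');
    // c = [[1, 2],
    //      [3, 4]]
    INDArray f = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2}, 'f');
    // f = [[1, 3],
    //      [2, 4]]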
@@ -334,6 +334,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
     public void save(File to) {
         try (FileOutputStream fos = new FileOutputStream(to, false);
                         BufferedOutputStream bos = new BufferedOutputStream(fos)) {
+            to.mkdirs();
             save(bos);
         } catch (IOException e) {
             throw new RuntimeException(e);
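
Editor's note on this hunk: to.mkdirs() targets the file itself and runs only after the FileOutputStream has already been opened in the try-with-resources header. A sketch of the presumably intended order, creating the parent directory before the stream (an editorial suggestion, not part of the commit):

    public void save(File to) {
        File parent = to.getParentFile();
        if (parent != null) {
            parent.mkdirs(); // ensure the directory exists before opening the stream
        }
        try (FileOutputStream fos = new FileOutputStream(to, false);
             BufferedOutputStream bos = new BufferedOutputStream(fos)) {
            save(bos);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }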
@@ -5121,7 +5121,7 @@ public class Nd4j {
             Nd4j.backend = backend;
             updateNd4jContext();
             props = Nd4jContext.getInstance().getConf();
-            logger.info("Properties for Nd4jContext " + props);
+            log.debug("Properties for Nd4jContext {}", props);
             PropertyParser pp = new PropertyParser(props);

             String otherDtype = pp.toString(ND4JSystemProperties.DTYPE);
@@ -166,10 +166,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         int seed = 123;
         int listenerFreq = 1;

-        LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
+        final LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
                         new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));

-        NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
+        final var builder = NeuralNetConfiguration.builder().seed(seed)
                         .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                         .layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
@@ -181,7 +181,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
                                         .build())
                         .inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));

-        MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
+        final MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
         model.init();

         model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
@@ -45,6 +45,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.api.BaseTrainingListener;
@@ -924,8 +925,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
         };

         for(EpochTerminationCondition e : etc ){
-            String s = NeuralNetConfiguration.mapper().writeValueAsString(e);
-            EpochTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, EpochTerminationCondition.class);
+            String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(e);
+            EpochTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, EpochTerminationCondition.class);
             assertEquals(e, c);
         }

@@ -936,8 +937,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
         };

         for(IterationTerminationCondition i : itc ){
-            String s = NeuralNetConfiguration.mapper().writeValueAsString(i);
-            IterationTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, IterationTerminationCondition.class);
+            String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(i);
+            IterationTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, IterationTerminationCondition.class);
             assertEquals(i, c);
         }
     }
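The hunks above swap NeuralNetConfiguration.mapper() for the new CavisMapper when (de)serializing termination conditions. A minimal round-trip sketch, assuming CavisMapper.getMapper(CavisMapper.Type.JSON) returns a Jackson-style mapper exactly as used in the hunks; MaxEpochsTerminationCondition is one concrete condition from the early-stopping API, chosen here only for illustration:

    import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
    import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
    import org.deeplearning4j.nn.conf.serde.CavisMapper;

    public class CavisMapperRoundTrip {
        public static void main(String[] args) throws Exception {
            EpochTerminationCondition original = new MaxEpochsTerminationCondition(30);
            // Write the condition to JSON via the new mapper ...
            String json = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(original);
            // ... and read it back; the tests above assert the round trip preserves equality.
            EpochTerminationCondition restored =
                    CavisMapper.getMapper(CavisMapper.Type.JSON)
                            .readValue(json, EpochTerminationCondition.class);
            System.out.println(original.equals(restored)); // expected: true
        }
    }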
@@ -309,7 +309,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {

         try {
             NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict)
-                            .list()
                             .layer(0, ConvolutionLayer.builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5)
                                             .build())
                             .layer(1, OutputLayer.builder().nOut(10).build())
@@ -77,7 +77,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                 NeuralNetConfiguration.builder().updater(new NoOp())
                         .dataType(DataType.DOUBLE)
                         .seed(12345L)
-                        .dist(new NormalDistribution(0, 1)).list()
+                        .weightInit(new NormalDistribution(0, 1))
                         .layer(0, DenseLayer.builder().nIn(4).nOut(3)
                                 .activation(Activation.IDENTITY).build())
                         .layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
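The hunk above collapses the `.dist(new NormalDistribution(0, 1)).list()` pair into a single `.weightInit(new NormalDistribution(0, 1))` call. A minimal sketch of the resulting builder usage, assuming the weightInit(Distribution) overload shown on the new side of the hunk; the layer sizes are illustrative only:

    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.learning.config.NoOp;

    public class WeightInitFromDistribution {
        public static void main(String[] args) {
            // Weights drawn from N(0, 1) passed straight to weightInit(...);
            // no intermediate .list() call is needed anymore.
            NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                    .updater(new NoOp())
                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .weightInit(new NormalDistribution(0, 1))
                    .layer(0, DenseLayer.builder().nIn(4).nOut(3)
                            .activation(Activation.IDENTITY).build())
                    .layer(1, OutputLayer.builder().nIn(3).nOut(3).build())
                    .build();
            System.out.println(conf.toJson());
        }
    }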
@@ -122,7 +122,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                     .dataType(DataType.DOUBLE)
                     .updater(new NoOp()).seed(12345L)
                     .dist(new NormalDistribution(0, 2)).list()
-                    .layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
+                    .layer(0, Convolution2D.builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
                             .activation(Activation.IDENTITY).build())
                     .layer(1,BatchNormalization.builder().useLogStd(useLogStd).build())
                     .layer(2, ActivationLayer.builder().activation(Activation.TANH).build())
@@ -91,9 +91,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                   .updater(new NoOp())
                   .dist(new NormalDistribution(0, 1))
                   .convolutionMode(ConvolutionMode.Same)
-                  .list()
                   .layer(
-                      Convolution1DLayer.builder()
+                      Convolution1D.builder()
                           .activation(afn)
                           .kernelSize(kernel)
                           .stride(stride)
@@ -202,7 +201,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                     .dist(new NormalDistribution(0, 1))
                     .convolutionMode(ConvolutionMode.Same)
                     .layer(
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -211,7 +210,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                             .build())
                     .layer(Cropping1D.builder(cropping).build())
                     .layer(
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -317,7 +316,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                     .dist(new NormalDistribution(0, 1))
                     .convolutionMode(ConvolutionMode.Same)
                     .layer(
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -326,7 +325,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                             .build())
                     .layer(ZeroPadding1DLayer.builder(zeroPadding).build())
                     .layer(
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -435,10 +434,9 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                     .updater(new NoOp())
                     .dist(new NormalDistribution(0, 1))
                     .convolutionMode(ConvolutionMode.Same)
-                    .list()
                     .layer(
                         0,
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -447,7 +445,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                             .build())
                     .layer(
                         1,
-                        Convolution1DLayer.builder()
+                        Convolution1D.builder()
                             .activation(afn)
                             .kernelSize(kernel)
                             .stride(stride)
@@ -461,6 +459,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                             .stride(stride)
                             .padding(padding)
                             .pnorm(pnorm)
+                                .name("SubsamplingLayer")
                             .build())
                     .layer(
                         3,
@@ -548,7 +547,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                   .seed(12345)
                   .list()
                   .layer(
-                      Convolution1DLayer.builder()
+                      Convolution1D.builder()
                           .kernelSize(2)
                           .rnnDataFormat(RNNFormat.NCW)
                           .stride(stride)
@@ -562,7 +561,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                           .pnorm(pnorm)
                           .build())
                   .layer(
-                      Convolution1DLayer.builder()
+                      Convolution1D.builder()
                           .kernelSize(2)
                           .rnnDataFormat(RNNFormat.NCW)
                           .stride(stride)
@@ -655,7 +654,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
               .seed(12345)
               .list()
               .layer(
-                  Convolution1DLayer.builder()
+                  Convolution1D.builder()
                       .kernelSize(k)
                       .dilation(d)
                       .hasBias(hasBias)
@@ -664,7 +663,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                       .nOut(convNOut1)
                       .build())
               .layer(
-                  Convolution1DLayer.builder()
+                  Convolution1D.builder()
                       .kernelSize(k)
                       .dilation(d)
                       .convolutionMode(ConvolutionMode.Causal)
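The hunks above are a file-wide rename: every Convolution1DLayer.builder() becomes Convolution1D.builder(), while the builder surface used in the tests (kernelSize, stride, dilation, padding, rnnDataFormat, hasBias, nIn/nOut) is left untouched. A minimal configuration sketch under that assumption; the channel counts and sequence length here are illustrative, not taken from the diff:

    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.Convolution1D;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class Convolution1DRenameSketch {
        public static void main(String[] args) {
            // Same builder chain as before the rename; only the class name changed.
            NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                    .layer(Convolution1D.builder()
                            .activation(Activation.TANH)
                            .kernelSize(2)
                            .stride(1)
                            .nIn(3)
                            .nOut(4)
                            .rnnDataFormat(RNNFormat.NCW)
                            .build())
                    .layer(RnnOutputLayer.builder()
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX)
                            .nOut(4)
                            .build())
                    .inputType(InputType.recurrent(3, 10))
                    .build();
            System.out.println(conf.toJson());
        }
    }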
@@ -0,0 +1,811 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  *  See the NOTICE file distributed with this work for additional
+ *  *  information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+package org.deeplearning4j.gradientcheck;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.Convolution1DUtils;
+import org.junit.jupiter.api.Test;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.learning.config.NoOp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@Slf4j
+public class CNN1DNewGradientCheckTest extends BaseDL4JTest {
+  private static final boolean PRINT_RESULTS = true;
+  private static final boolean RETURN_ON_FIRST_FAILURE = false;
+  private static final double DEFAULT_EPS = 1e-6;
+  private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
+  private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
+
+  static {
+    Nd4j.setDataType(DataType.DOUBLE);
+  }
+
+  @Test
+  public void testCnn1D() {
+    int minibatchSize = 4;
+    int[] dataChannels = {4, 10}; //the input
+    int[] kernels = {2,4,5,8};
+    int stride = 2;
+    int padding = 3;
+    int seriesLength = 300;
+
+    for (int kernel : kernels) {
+      for (int dChannels : dataChannels) {
+        int numLabels = ((seriesLength + (2 * padding) - kernel) / stride) + 1;
+        final NeuralNetConfiguration conf =
+            NeuralNetConfiguration.builder()
+                .dataType(DataType.DOUBLE)
+                .updater(new NoOp())
+                .dist(new NormalDistribution(0, 1))
+                .convolutionMode(ConvolutionMode.Same)
+                .layer(
+                    Convolution1DNew.builder()
+                        .activation(Activation.RELU)
+                        .kernelSize(kernel)
+                        .stride(stride)
+                        .padding(padding)
+                        .nIn(dChannels) // channels
+                        .nOut(3)
+                        .rnnDataFormat(RNNFormat.NCW)
+                        .build())
+                .layer(
+                    RnnOutputLayer.builder()
+                        .lossFunction(LossFunctions.LossFunction.MCXENT)
+                        .activation(Activation.SOFTMAX)
+                        .nOut(4)
+                        .build())
+                .inputType(InputType.recurrent(dChannels, seriesLength))
+                .build();
+
+        INDArray input = Nd4j.rand(minibatchSize, dChannels, seriesLength);
+        INDArray labels = Nd4j.zeros(minibatchSize, 4, numLabels);
+        for (int i = 0; i < minibatchSize; i++) {
+          for (int j = 0; j < numLabels; j++) {
+            labels.putScalar(new int[] {i, i % 4, j}, 1.0);
+          }
+        }
+        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        String msg =
+            "Minibatch="
+                + minibatchSize
+                + ", activationFn="
+                + Activation.RELU
+                + ", kernel = "
+                + kernel;
+
+        System.out.println(msg);
+        for (int j = 0; j < net.getnLayers(); j++)
+          System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
+/**
+        List<Pair<INDArray, INDArray>> iter = new java.util.ArrayList<>(Collections.emptyList());
+        iter.add(new Pair<>(input, labels));
+        for(int x=0;x<100; x++) net.fit(input, labels);
+        Evaluation eval = net.evaluate(new INDArrayDataSetIterator(iter,2), Arrays.asList(new String[]{"One", "Two", "Three", "Four"}));
+        // net.fit(input, labels);
+        eval.eval(labels, net.output(input));
+
+ **/
+        boolean gradOK =
+            GradientCheckUtil.checkGradients(
+                net,
+                DEFAULT_EPS,
+                DEFAULT_MAX_REL_ERROR,
+                DEFAULT_MIN_ABS_ERROR,
+                PRINT_RESULTS,
+                RETURN_ON_FIRST_FAILURE,
+                input,
+                labels);
+
+        assertTrue(gradOK, msg);
+        TestUtils.testModelSerialization(net);
+      }
+      }
+
+
+  }
+
+
+  @Test
+  public void testCnn1DWithLocallyConnected1D() {
+    Nd4j.getRandom().setSeed(1337);
+
+    int[] minibatchSizes = {2, 3};
+    int length = 25;
+    int convNIn = 18;
+    int convNOut1 = 3;
+    int convNOut2 = 4;
+    int finalNOut = 4;
+
+    int[] kernels = {1,2,4};
+    int stride = 1;
+    int padding = 0;
+
+    Activation[] activations = {Activation.SIGMOID};
+
+    for (Activation afn : activations) {
+      for (int minibatchSize : minibatchSizes) {
+        for (int kernel : kernels) {
+          INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
+          INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
+          for (int i = 0; i < minibatchSize; i++) {
+            for (int j = 0; j < length; j++) {
+              labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
+            }
+          }
+
+          NeuralNetConfiguration conf =
+              NeuralNetConfiguration.builder()
+                  .dataType(DataType.DOUBLE)
+                  .updater(new NoOp())
+                  .dist(new NormalDistribution(0, 1))
+                  .convolutionMode(ConvolutionMode.Same)
+                  .layer(
+                      Convolution1DNew.builder()
+                          .activation(afn)
+                          .kernelSize(kernel)
+                          .stride(stride)
+                          .padding(padding)
+                          .nIn(convNIn)
+                          .nOut(convNOut1)
+                          .rnnDataFormat(RNNFormat.NCW)
+                          .build())
+                  .layer(
+                      LocallyConnected1D.builder()
+                          .activation(afn)
+                          .kernelSize(kernel)
+                          .stride(stride)
+                          .padding(padding)
+                          .nIn(convNOut1)
+                          .nOut(convNOut2)
+                          .hasBias(false)
+                          .build())
+                  .layer(
+                      RnnOutputLayer.builder()
+                          .lossFunction(LossFunctions.LossFunction.MCXENT)
+                          .activation(Activation.SOFTMAX)
+                          .nOut(finalNOut)
+                          .build())
+                  .inputType(InputType.recurrent(convNIn, length))
+                  .build();
+
+          String json = conf.toJson();
+          NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
+          assertEquals(conf, c2);
+
+          MultiLayerNetwork net = new MultiLayerNetwork(conf);
+          net.init();
+
+          String msg =
+              "Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
+
+          if (PRINT_RESULTS) {
+            System.out.println(msg);
+            //                        for (int j = 0; j < net.getnLayers(); j++)
+            //                            System.out.println("ILayer " + j + " # params: " +
+            // net.getLayer(j).numParams());
+          }
+
+          boolean gradOK =
+              GradientCheckUtil.checkGradients(
+                  net,
+                  DEFAULT_EPS,
+                  DEFAULT_MAX_REL_ERROR,
+                  DEFAULT_MIN_ABS_ERROR,
+                  PRINT_RESULTS,
+                  RETURN_ON_FIRST_FAILURE,
+                  input,
+                  labels);
+
+          assertTrue(gradOK, msg);
+
+          TestUtils.testModelSerialization(net);
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testCnn1DWithCropping1D() {
+    Nd4j.getRandom().setSeed(1337);
+
+    int[] minibatchSizes = {1, 3};
+    int length = 7;
+    int convNIn = 2;
+    int convNOut1 = 3;
+    int convNOut2 = 4;
+    int finalNOut = 4;
+
+    int[] kernels = {1, 2, 4};
+    int stride = 1;
+
+    int padding = 0;
+    int cropping = 1;
+    int croppedLength = length - 2 * cropping;
+
+    Activation[] activations = {Activation.SIGMOID};
+    SubsamplingLayer.PoolingType[] poolingTypes =
+        new SubsamplingLayer.PoolingType[] {
+          SubsamplingLayer.PoolingType.MAX,
+          SubsamplingLayer.PoolingType.AVG,
+          SubsamplingLayer.PoolingType.PNORM
+        };
+
+    for (Activation afn : activations) {
+      for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
+        for (int minibatchSize : minibatchSizes) {
+          for (int kernel : kernels) {
+            INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
+            INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
+            for (int i = 0; i < minibatchSize; i++) {
+              for (int j = 0; j < croppedLength; j++) {
+                labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
+              }
+            }
+
+            NeuralNetConfiguration conf =
+                NeuralNetConfiguration.builder()
+                    .dataType(DataType.DOUBLE)
+                    .updater(new NoOp())
+                    .dist(new NormalDistribution(0, 1))
+                    .convolutionMode(ConvolutionMode.Same)
+                    .layer(
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut1)
+                            .build())
+                    .layer(Cropping1D.builder(cropping).build())
+                    .layer(
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut2)
+                            .build())
+                    .layer(
+                        RnnOutputLayer.builder()
+                            .lossFunction(LossFunctions.LossFunction.MCXENT)
+                            .activation(Activation.SOFTMAX)
+                            .nOut(finalNOut)
+                            .build())
+                    .inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
+                    .build();
+
+            String json = conf.toJson();
+            NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
+            assertEquals(conf, c2);
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            String msg =
+                "PoolingType="
+                    + poolingType
+                    + ", minibatch="
+                    + minibatchSize
+                    + ", activationFn="
+                    + afn
+                    + ", kernel = "
+                    + kernel;
+
+            if (PRINT_RESULTS) {
+              System.out.println(msg);
+              //                            for (int j = 0; j < net.getnLayers(); j++)
+              //                                System.out.println("ILayer " + j + " # params: " +
+              // net.getLayer(j).numParams());
+            }
+
+            boolean gradOK =
+                GradientCheckUtil.checkGradients(
+                    net,
+                    DEFAULT_EPS,
+                    DEFAULT_MAX_REL_ERROR,
+                    DEFAULT_MIN_ABS_ERROR,
+                    PRINT_RESULTS,
+                    RETURN_ON_FIRST_FAILURE,
+                    input,
+                    labels);
+
+            assertTrue(gradOK, msg);
+
+            TestUtils.testModelSerialization(net);
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testCnn1DWithZeroPadding1D() {
+    Nd4j.getRandom().setSeed(1337);
+
+    int[] minibatchSizes = {1, 3};
+    int length = 7;
+    int convNIn = 2;
+    int convNOut1 = 3;
+    int convNOut2 = 4;
+    int finalNOut = 4;
+
+    int[] kernels = {1, 2, 4};
+    int stride = 1;
+    int pnorm = 2;
+
+    int padding = 0;
+    int zeroPadding = 2;
+    int paddedLength = length + 2 * zeroPadding;
+
+    Activation[] activations = {Activation.SIGMOID};
+    SubsamplingLayer.PoolingType[] poolingTypes =
+        new SubsamplingLayer.PoolingType[] {
+          SubsamplingLayer.PoolingType.MAX,
+          SubsamplingLayer.PoolingType.AVG,
+          SubsamplingLayer.PoolingType.PNORM
+        };
+
+    for (Activation afn : activations) {
+      for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
+        for (int minibatchSize : minibatchSizes) {
+          for (int kernel : kernels) {
+            INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
+            INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
+            for (int i = 0; i < minibatchSize; i++) {
+              for (int j = 0; j < paddedLength; j++) {
+                labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
+              }
+            }
+
+            NeuralNetConfiguration conf =
+                NeuralNetConfiguration.builder()
+                    .dataType(DataType.DOUBLE)
+                    .updater(new NoOp())
+                    .dist(new NormalDistribution(0, 1))
+                    .convolutionMode(ConvolutionMode.Same)
+                    .layer(
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(2, kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut1)
+                            .build())
+                    .layer(ZeroPadding1DLayer.builder(zeroPadding).build())
+                    .layer(
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut2)
+                            .build())
+                    .layer(ZeroPadding1DLayer.builder(0).build())
+                    .layer(
+                        Subsampling1DLayer.builder(poolingType)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .pnorm(pnorm)
+                            .build())
+                    .layer(
+                        RnnOutputLayer.builder()
+                            .lossFunction(LossFunctions.LossFunction.MCXENT)
+                            .activation(Activation.SOFTMAX)
+                            .nOut(finalNOut)
+                            .build())
+                    .inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
+                    .build();
+
+            String json = conf.toJson();
+            NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
+            assertEquals(conf, c2);
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            String msg =
+                "PoolingType="
+                    + poolingType
+                    + ", minibatch="
+                    + minibatchSize
+                    + ", activationFn="
+                    + afn
+                    + ", kernel = "
+                    + kernel;
+
+            if (PRINT_RESULTS) {
+              System.out.println(msg);
+              //                            for (int j = 0; j < net.getnLayers(); j++)
+              //                                System.out.println("ILayer " + j + " # params: " +
+              // net.getLayer(j).numParams());
+            }
+
+            boolean gradOK =
+                GradientCheckUtil.checkGradients(
+                    net,
+                    DEFAULT_EPS,
+                    DEFAULT_MAX_REL_ERROR,
+                    DEFAULT_MIN_ABS_ERROR,
+                    PRINT_RESULTS,
+                    RETURN_ON_FIRST_FAILURE,
+                    input,
+                    labels);
+
+            assertTrue(gradOK, msg);
+            TestUtils.testModelSerialization(net);
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testCnn1DWithSubsampling1D() {
+    Nd4j.getRandom().setSeed(12345);
+
+    int[] minibatchSizes = {1, 3};
+    int length = 7;
+    int convNIn = 2;
+    int convNOut1 = 3;
+    int convNOut2 = 4;
+    int finalNOut = 4;
+
+    int[] kernels = {1, 2, 4};
+    int stride = 1;
+    int padding = 0;
+    int pnorm = 2;
+
+    Activation[] activations = {Activation.SIGMOID, Activation.TANH};
+    SubsamplingLayer.PoolingType[] poolingTypes =
+        new SubsamplingLayer.PoolingType[] {
+          SubsamplingLayer.PoolingType.MAX,
+          SubsamplingLayer.PoolingType.AVG,
+          SubsamplingLayer.PoolingType.PNORM
+        };
+
+    for (Activation afn : activations) {
+      for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
+        for (int minibatchSize : minibatchSizes) {
+          for (int kernel : kernels) {
+            INDArray input = Nd4j.rand(minibatchSize, convNIn, length);
+            INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
+            for (int i = 0; i < minibatchSize; i++) {
+              for (int j = 0; j < length; j++) {
+                labels.putScalar(new int[] {i, i % finalNOut, j}, 1.0);
+              }
+            }
+
+            NeuralNetConfiguration conf =
+                NeuralNetConfiguration.builder()
+                    .dataType(DataType.DOUBLE)
+                    .updater(new NoOp())
+                    .dist(new NormalDistribution(0, 1))
+                    .convolutionMode(ConvolutionMode.Same)
+                    .layer(
+                        0,
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut1)
+                            .build())
+                    .layer(
+                        1,
+                        Convolution1DNew.builder()
+                            .activation(afn)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .nOut(convNOut2)
+                            .build())
+                    .layer(
+                        2,
+                        Subsampling1DLayer.builder(poolingType)
+                            .kernelSize(kernel)
+                            .stride(stride)
+                            .padding(padding)
+                            .pnorm(pnorm)
+                            .name("SubsamplingLayer")
+                            .build())
+                    .layer(
+                        3,
+                        RnnOutputLayer.builder()
+                            .lossFunction(LossFunctions.LossFunction.MCXENT)
+                            .activation(Activation.SOFTMAX)
+                            .nOut(finalNOut)
+                            .build())
+                    .inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
+                    .build();
+
+            String json = conf.toJson();
+            NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json);
+            assertEquals(conf, c2);
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            String msg =
+                "PoolingType="
+                    + poolingType
+                    + ", minibatch="
+                    + minibatchSize
+                    + ", activationFn="
+                    + afn
+                    + ", kernel = "
+                    + kernel;
+
+            if (PRINT_RESULTS) {
+              System.out.println(msg);
+              //                            for (int j = 0; j < net.getnLayers(); j++)
+              //                                System.out.println("ILayer " + j + " # params: " +
+              // net.getLayer(j).numParams());
+            }
+
+            boolean gradOK =
+                GradientCheckUtil.checkGradients(
+                    net,
+                    DEFAULT_EPS,
+                    DEFAULT_MAX_REL_ERROR,
+                    DEFAULT_MIN_ABS_ERROR,
+                    PRINT_RESULTS,
+                    RETURN_ON_FIRST_FAILURE,
+                    input,
+                    labels);
+
+            assertTrue(gradOK, msg);
+            TestUtils.testModelSerialization(net);
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testCnn1dWithMasking() {
+    int length = 12;
+    int convNIn = 2;
+    int convNOut1 = 3;
+    int convNOut2 = 4;
+    int finalNOut = 3;
+
+    int pnorm = 2;
+
+    SubsamplingLayer.PoolingType[] poolingTypes =
+        new SubsamplingLayer.PoolingType[] {
+          SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG
+        };
+
+    for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
+      for (ConvolutionMode cm :
+          new ConvolutionMode[] {ConvolutionMode.Same, ConvolutionMode.Truncate}) {
+        for (int stride : new int[] {1, 2}) {
+          String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
+          log.info("Starting test: " + s);
+          Nd4j.getRandom().setSeed(12345);
+
+          NeuralNetConfiguration conf =
+              NeuralNetConfiguration.builder()
+                  .dataType(DataType.DOUBLE)
+                  .updater(new NoOp())
+                  .activation(Activation.TANH)
+                  .dist(new NormalDistribution(0, 1))
+                  .convolutionMode(cm)
+                  .seed(12345)
+                  .layer(
+                      Convolution1DNew.builder()
+                          .kernelSize(2)
+                          .rnnDataFormat(RNNFormat.NCW)
+                          .stride(stride)
+                          .nIn(convNIn)
+                          .nOut(convNOut1)
+                          .build())
+                  .layer(
+                      Subsampling1DLayer.builder(poolingType)
+                          .kernelSize(2)
+                          .stride(stride)
+                          .pnorm(pnorm)
+                          .build())
+                  .layer(
+                      Convolution1DNew.builder()
+                          .kernelSize(2)
+                          .rnnDataFormat(RNNFormat.NCW)
+                          .stride(stride)
+                          .nIn(convNOut1)
+                          .nOut(convNOut2)
+                          .build())
+                  .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
+                  .layer(
+                      OutputLayer.builder()
+                          .lossFunction(LossFunctions.LossFunction.MCXENT)
+                          .activation(Activation.SOFTMAX)
+                          .nOut(finalNOut)
+                          .build())
+                  .inputType(InputType.recurrent(convNIn, length))
+                  .build();
+
+          MultiLayerNetwork net = new MultiLayerNetwork(conf);
+          net.init();
+
+          INDArray f = Nd4j.rand(2, convNIn, length);
+          INDArray fm = Nd4j.create(2, length);
+          fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
+          fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
+
+          INDArray label = TestUtils.randomOneHot(2, finalNOut);
+
+          boolean gradOK =
+              GradientCheckUtil.checkGradients(
+                  new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
+
+          assertTrue(gradOK, s);
+          TestUtils.testModelSerialization(net);
+
+          // TODO also check that masked step values don't impact forward pass, score or gradients
+
+          DataSet ds = new DataSet(f, label, fm, null);
+          double scoreBefore = net.score(ds);
+          net.setInput(f);
+          net.setLabels(label);
+          net.setLayerMaskArrays(fm, null);
+          net.computeGradientAndScore();
+          INDArray gradBefore = net.getFlattenedGradients().dup();
+          f.putScalar(1, 0, 10, 10.0);
+          f.putScalar(1, 1, 11, 20.0);
+          double scoreAfter = net.score(ds);
+          net.setInput(f);
+          net.setLabels(label);
 | 
				
			||||||
 | 
					          net.setLayerMaskArrays(fm, null);
 | 
				
			||||||
 | 
					          net.computeGradientAndScore();
 | 
				
			||||||
 | 
					          INDArray gradAfter = net.getFlattenedGradients().dup();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          assertEquals(scoreBefore, scoreAfter, 1e-6);
 | 
				
			||||||
 | 
					          assertEquals(gradBefore, gradAfter);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  public void testCnn1Causal() throws Exception {
 | 
				
			||||||
 | 
					    int convNIn = 2;
 | 
				
			||||||
 | 
					    int convNOut1 = 3;
 | 
				
			||||||
 | 
					    int convNOut2 = 4;
 | 
				
			||||||
 | 
					    int finalNOut = 3;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int[] lengths = {11, 12, 13, 9, 10, 11};
 | 
				
			||||||
 | 
					    int[] kernels = {2, 3, 2, 4, 2, 3};
 | 
				
			||||||
 | 
					    int[] dilations = {1, 1, 2, 1, 2, 1};
 | 
				
			||||||
 | 
					    int[] strides = {1, 2, 1, 2, 1, 1};
 | 
				
			||||||
 | 
					    boolean[] masks = {false, true, false, true, false, true};
 | 
				
			||||||
 | 
					    boolean[] hasB = {true, false, true, false, true, true};
 | 
				
			||||||
 | 
					    for (int i = 0; i < lengths.length; i++) {
 | 
				
			||||||
 | 
					      int length = lengths[i];
 | 
				
			||||||
 | 
					      int k = kernels[i];
 | 
				
			||||||
 | 
					      int d = dilations[i];
 | 
				
			||||||
 | 
					      int st = strides[i];
 | 
				
			||||||
 | 
					      boolean mask = masks[i];
 | 
				
			||||||
 | 
					      boolean hasBias = hasB[i];
 | 
				
			||||||
 | 
					      // TODO has bias
 | 
				
			||||||
 | 
					      String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
 | 
				
			||||||
 | 
					      log.info("Starting test: " + s);
 | 
				
			||||||
 | 
					      Nd4j.getRandom().setSeed(12345);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      NeuralNetConfiguration conf =
 | 
				
			||||||
 | 
					          NeuralNetConfiguration.builder()
 | 
				
			||||||
 | 
					              .dataType(DataType.DOUBLE)
 | 
				
			||||||
 | 
					              .updater(new NoOp())
 | 
				
			||||||
 | 
					              .activation(Activation.TANH)
 | 
				
			||||||
 | 
					              .weightInit(new NormalDistribution(0, 1))
 | 
				
			||||||
 | 
					              .seed(12345)
 | 
				
			||||||
 | 
					              .layer(
 | 
				
			||||||
 | 
					                  Convolution1DNew.builder()
 | 
				
			||||||
 | 
					                      .kernelSize(k)
 | 
				
			||||||
 | 
					                      .dilation(d)
 | 
				
			||||||
 | 
					                      .hasBias(hasBias)
 | 
				
			||||||
 | 
					                      .convolutionMode(ConvolutionMode.Causal)
 | 
				
			||||||
 | 
					                      .stride(st)
 | 
				
			||||||
 | 
					                      .nOut(convNOut1)
 | 
				
			||||||
 | 
					                      .build())
 | 
				
			||||||
 | 
					              .layer(
 | 
				
			||||||
 | 
					                  Convolution1DNew.builder()
 | 
				
			||||||
 | 
					                      .kernelSize(k)
 | 
				
			||||||
 | 
					                      .dilation(d)
 | 
				
			||||||
 | 
					                      .convolutionMode(ConvolutionMode.Causal)
 | 
				
			||||||
 | 
					                      .stride(st)
 | 
				
			||||||
 | 
					                      .nOut(convNOut2)
 | 
				
			||||||
 | 
					                      .build())
 | 
				
			||||||
 | 
					              .layer(
 | 
				
			||||||
 | 
					                  RnnOutputLayer.builder()
 | 
				
			||||||
 | 
					                      .lossFunction(LossFunctions.LossFunction.MCXENT)
 | 
				
			||||||
 | 
					                      .activation(Activation.SOFTMAX)
 | 
				
			||||||
 | 
					                      .nOut(finalNOut)
 | 
				
			||||||
 | 
					                      .build())
 | 
				
			||||||
 | 
					              .inputType(InputType.recurrent(convNIn, length, RNNFormat.NCW))
 | 
				
			||||||
 | 
					              .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
 | 
					      net.init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
 | 
				
			||||||
 | 
					      INDArray fm = null;
 | 
				
			||||||
 | 
					      if (mask) {
 | 
				
			||||||
 | 
					        fm = Nd4j.create(2, length);
 | 
				
			||||||
 | 
					        fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
 | 
				
			||||||
 | 
					        fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
 | 
				
			||||||
 | 
					      long outSize2 =
 | 
				
			||||||
 | 
					          Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      String msg =
 | 
				
			||||||
 | 
					              "Minibatch="
 | 
				
			||||||
 | 
					                      + 1
 | 
				
			||||||
 | 
					                      + ", activationFn="
 | 
				
			||||||
 | 
					                      + Activation.RELU
 | 
				
			||||||
 | 
					                      + ", kernel = "
 | 
				
			||||||
 | 
					                      + k;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      System.out.println(msg);
 | 
				
			||||||
 | 
					      for (int j = 0; j < net.getnLayers(); j++)
 | 
				
			||||||
 | 
					        System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      boolean gradOK =
 | 
				
			||||||
 | 
					          GradientCheckUtil.checkGradients(
 | 
				
			||||||
 | 
					              new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      assertTrue(gradOK, s);
 | 
				
			||||||
 | 
					      TestUtils.testModelSerialization(net);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
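The causal test above derives its label length from Convolution1DUtils.getOutputSize with zero explicit padding. As a reading aid, a minimal sketch of that arithmetic, assuming Causal mode left-pads by (kernel - 1) * dilation; the helper below is illustrative, not the library method:

    // Causal 1D output length: left padding keeps every output step from seeing the future.
    static long causalOutSize(long in, int kernel, int stride, int dilation) {
        long effectiveKernel = (long) (kernel - 1) * dilation + 1; // dilated receptive field
        long padded = in + (effectiveKernel - 1);                  // pad on the left only
        return (padded - effectiveKernel) / stride + 1;            // simplifies to ceil(in / stride)
    }

With stride 1 this returns the input length unchanged, which is why outSize2 above stays aligned with the label time series after stacking two causal Convolution1DNew layers.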
@@ -112,9 +112,8 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
 
                                             NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                                                     .dataType(DataType.DOUBLE)
-                                                    .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
+                                                    .updater(new NoOp())
                                                     .dist(new NormalDistribution(0, 1))
-                                                    .list()
                                                     .layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
                                                             .stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
                                                             .convolutionMode(mode).dataFormat(df)
@@ -400,7 +399,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                                 .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
                                 .dist(new NormalDistribution(0, 1))
                                 .seed(12345)
-                                .list()
                                 .layer(0, Convolution3D.builder().activation(afn).kernelSize(1, 1, 1)
                                         .nIn(convNIn).nOut(convNOut).hasBias(false)
                                         .convolutionMode(mode).dataFormat(df)
@@ -108,8 +108,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                   .updater(new NoOp())
                   .weightInit(WeightInit.XAVIER)
                   .seed(12345L)
-                  .list()
-                  .layer(0, ConvolutionLayer.builder(1, 1).nOut(6).activation(afn).build())
+
+                  .layer(0, Convolution2D.builder().kernelSize(1).stride(1).nOut(6).activation(afn).build())
                   .layer(1, OutputLayer.builder(lf).activation(outputActivation).nOut(3).build())
                   .inputType(InputType.convolutionalFlat(1, 4, 1));
 
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 | 
				
			|||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.DenseLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.LossLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.OutputLayer;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.serde.CavisMapper;
 | 
				
			||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
					import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
				
			||||||
import org.junit.jupiter.api.Test;
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
import org.nd4j.linalg.activations.Activation;
 | 
					import org.nd4j.linalg.activations.Activation;
 | 
				
			||||||
@ -336,7 +337,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
 | 
				
			|||||||
                // to ensure that we carry the parameters through
 | 
					                // to ensure that we carry the parameters through
 | 
				
			||||||
                // the serializer.
 | 
					                // the serializer.
 | 
				
			||||||
                try{
 | 
					                try{
 | 
				
			||||||
                    ObjectMapper m = NeuralNetConfiguration.mapper();
 | 
					                    ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
 | 
				
			||||||
                    String s = m.writeValueAsString(lossFunctions[i]);
 | 
					                    String s = m.writeValueAsString(lossFunctions[i]);
 | 
				
			||||||
                    ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
 | 
					                    ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
 | 
				
			||||||
                    lossFunctions[i] = lf2;
 | 
					                    lossFunctions[i] = lf2;
 | 
				
			||||||
 | 
				
			|||||||
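Both hunks swap the old NeuralNetConfiguration.mapper() entry point for CavisMapper.getMapper(CavisMapper.Type.JSON). A minimal round-trip sketch of the new call, assuming it returns an ordinary Jackson ObjectMapper as the usage above implies; LossMCXENT is just one concrete ILossFunction picked for illustration:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.deeplearning4j.nn.conf.serde.CavisMapper;
    import org.nd4j.linalg.lossfunctions.ILossFunction;
    import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

    public class CavisMapperRoundTrip {
        public static void main(String[] args) throws Exception {
            // Same mapper the updated tests obtain; Type.JSON selects the JSON-configured instance.
            ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
            ILossFunction original = new LossMCXENT();
            String json = m.writeValueAsString(original);                 // serialize
            ILossFunction restored = m.readValue(json, LossMCXENT.class); // deserialize
            System.out.println(json + " -> equal: " + original.equals(restored));
        }
    }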
@@ -180,7 +180,7 @@ public class DTypeTests extends BaseDL4JTest {
             Pooling2D.class,        //Alias for SubsamplingLayer
             Convolution2D.class,    //Alias for ConvolutionLayer
             Pooling1D.class,        //Alias for Subsampling1D
-            Convolution1D.class,    //Alias for  Convolution1DLayer
+            Convolution1D.class,    //Alias for  Convolution1D
             TensorFlowCnnToFeedForwardPreProcessor.class    //Deprecated
     ));
 
@ -37,7 +37,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
 | 
				
			|||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
					import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
				
			||||||
import org.deeplearning4j.nn.workspace.ArrayType;
 | 
					import org.deeplearning4j.nn.workspace.ArrayType;
 | 
				
			||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 | 
					import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 | 
				
			||||||
import org.deeplearning4j.util.ConvolutionUtils;
 | 
					import org.deeplearning4j.util.Convolution2DUtils;
 | 
				
			||||||
import org.junit.jupiter.api.Test;
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
import org.junit.jupiter.api.Timeout;
 | 
					import org.junit.jupiter.api.Timeout;
 | 
				
			||||||
import org.nd4j.linalg.activations.Activation;
 | 
					import org.nd4j.linalg.activations.Activation;
 | 
				
			||||||
@ -1026,7 +1026,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
 | 
				
			|||||||
                } catch (DL4JInvalidInputException e) {
 | 
					                } catch (DL4JInvalidInputException e) {
 | 
				
			||||||
//                    e.printStackTrace();
 | 
					//                    e.printStackTrace();
 | 
				
			||||||
                    String msg = e.getMessage();
 | 
					                    String msg = e.getMessage();
 | 
				
			||||||
                    assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
 | 
					                    assertTrue(msg.contains(Convolution2DUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
				
			|||||||
@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			|||||||
import org.deeplearning4j.nn.conf.RNNFormat;
 | 
					import org.deeplearning4j.nn.conf.RNNFormat;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
					import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.*;
 | 
					import org.deeplearning4j.nn.conf.layers.*;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.Convolution1D;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
					import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
				
			||||||
import org.deeplearning4j.nn.weights.WeightInit;
 | 
					import org.deeplearning4j.nn.weights.WeightInit;
 | 
				
			||||||
@ -921,7 +921,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
 | 
				
			|||||||
        NeuralNetConfiguration.builder()
 | 
					        NeuralNetConfiguration.builder()
 | 
				
			||||||
            .convolutionMode(ConvolutionMode.Same)
 | 
					            .convolutionMode(ConvolutionMode.Same)
 | 
				
			||||||
            .layer(
 | 
					            .layer(
 | 
				
			||||||
                Convolution1DLayer.builder()
 | 
					                Convolution1D.builder()
 | 
				
			||||||
                    .nOut(3)
 | 
					                    .nOut(3)
 | 
				
			||||||
                    .kernelSize(2)
 | 
					                    .kernelSize(2)
 | 
				
			||||||
                    .activation(Activation.TANH)
 | 
					                    .activation(Activation.TANH)
 | 
				
			||||||
@ -975,7 +975,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  public void testConv1dCausalAllowed() {
 | 
					  public void testConv1dCausalAllowed() {
 | 
				
			||||||
    Convolution1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
 | 
					    Convolution1D.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
 | 
				
			||||||
    Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
 | 
					    Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
 | 
				
			|||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
 | 
					import org.deeplearning4j.nn.graph.ComputationGraph;
 | 
				
			||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
					import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
				
			||||||
import org.deeplearning4j.nn.weights.WeightInit;
 | 
					import org.deeplearning4j.nn.weights.WeightInit;
 | 
				
			||||||
import org.deeplearning4j.util.ConvolutionUtils;
 | 
					import org.deeplearning4j.util.Convolution2DUtils;
 | 
				
			||||||
import org.junit.jupiter.api.Test;
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
import org.nd4j.linalg.activations.Activation;
 | 
					import org.nd4j.linalg.activations.Activation;
 | 
				
			||||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
					import org.nd4j.linalg.api.ndarray.INDArray;
 | 
				
			||||||
@ -346,7 +346,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
        assertEquals(2, it.getHeight());
 | 
					        assertEquals(2, it.getHeight());
 | 
				
			||||||
        assertEquals(2, it.getWidth());
 | 
					        assertEquals(2, it.getWidth());
 | 
				
			||||||
        assertEquals(dOut, it.getChannels());
 | 
					        assertEquals(dOut, it.getChannels());
 | 
				
			||||||
        int[] outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
 | 
					        int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
 | 
				
			||||||
        assertEquals(2, outSize[0]);
 | 
					        assertEquals(2, outSize[0]);
 | 
				
			||||||
        assertEquals(2, outSize[1]);
 | 
					        assertEquals(2, outSize[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -357,7 +357,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
        assertEquals(2, it.getHeight());
 | 
					        assertEquals(2, it.getHeight());
 | 
				
			||||||
        assertEquals(2, it.getWidth());
 | 
					        assertEquals(2, it.getWidth());
 | 
				
			||||||
        assertEquals(dOut, it.getChannels());
 | 
					        assertEquals(dOut, it.getChannels());
 | 
				
			||||||
        outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
 | 
					        outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
 | 
				
			||||||
        assertEquals(2, outSize[0]);
 | 
					        assertEquals(2, outSize[0]);
 | 
				
			||||||
        assertEquals(2, outSize[1]);
 | 
					        assertEquals(2, outSize[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -367,7 +367,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
        assertEquals(3, it.getHeight());
 | 
					        assertEquals(3, it.getHeight());
 | 
				
			||||||
        assertEquals(3, it.getWidth());
 | 
					        assertEquals(3, it.getWidth());
 | 
				
			||||||
        assertEquals(dOut, it.getChannels());
 | 
					        assertEquals(dOut, it.getChannels());
 | 
				
			||||||
        outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
 | 
					        outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
 | 
				
			||||||
        assertEquals(3, outSize[0]);
 | 
					        assertEquals(3, outSize[0]);
 | 
				
			||||||
        assertEquals(3, outSize[1]);
 | 
					        assertEquals(3, outSize[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -397,7 +397,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
            System.out.println(e.getMessage());
 | 
					            System.out.println(e.getMessage());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
            outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
 | 
					            outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
 | 
				
			||||||
            fail("Exception expected");
 | 
					            fail("Exception expected");
 | 
				
			||||||
        } catch (DL4JException e) {
 | 
					        } catch (DL4JException e) {
 | 
				
			||||||
            System.out.println(e.getMessage());
 | 
					            System.out.println(e.getMessage());
 | 
				
			||||||
@ -409,7 +409,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
        assertEquals(1, it.getHeight());
 | 
					        assertEquals(1, it.getHeight());
 | 
				
			||||||
        assertEquals(1, it.getWidth());
 | 
					        assertEquals(1, it.getWidth());
 | 
				
			||||||
        assertEquals(dOut, it.getChannels());
 | 
					        assertEquals(dOut, it.getChannels());
 | 
				
			||||||
        outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
 | 
					        outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
 | 
				
			||||||
        assertEquals(1, outSize[0]);
 | 
					        assertEquals(1, outSize[0]);
 | 
				
			||||||
        assertEquals(1, outSize[1]);
 | 
					        assertEquals(1, outSize[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -419,7 +419,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
 | 
				
			|||||||
        assertEquals(2, it.getHeight());
 | 
					        assertEquals(2, it.getHeight());
 | 
				
			||||||
        assertEquals(2, it.getWidth());
 | 
					        assertEquals(2, it.getWidth());
 | 
				
			||||||
        assertEquals(dOut, it.getChannels());
 | 
					        assertEquals(dOut, it.getChannels());
 | 
				
			||||||
        outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
 | 
					        outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
 | 
				
			||||||
        assertEquals(2, outSize[0]);
 | 
					        assertEquals(2, outSize[0]);
 | 
				
			||||||
        assertEquals(2, outSize[1]);
 | 
					        assertEquals(2, outSize[1]);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
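All the expected sizes asserted in these hunks follow from the standard per-dimension output-size rule that Convolution2DUtils.getOutputSize encodes. A hedged sketch of that rule, as a hand-written helper rather than the library implementation (Strict additionally rejects inputs where the division leaves a remainder):

    // Per-dimension output size for 2D convolution/subsampling.
    static int outDim(int in, int kernel, int stride, int pad, boolean same) {
        if (same) {
            return (in + stride - 1) / stride;        // Same: ceil(in / stride), padding chosen to fit
        }
        return (in + 2 * pad - kernel) / stride + 1;  // Strict/Truncate: floor arithmetic
    }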
@@ -732,7 +732,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
                     .weightInit(WeightInit.XAVIER)
                     .convolutionMode(ConvolutionMode.Same)
                     .layer(rnn ? LSTM.builder().nOut(3).build() :
-                            Convolution1DLayer.builder().kernelSize(3).stride(1).nOut(3).build())
+                            Convolution1D.builder().kernelSize(3).stride(1).nOut(3).build())
                     .layer(BatchNormalization.builder().build())
                     .layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
                     .inputType(InputType.recurrent(3))
@@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
                 .graphBuilder()
                 .addInputs(inputName)
                 .setOutputs(output)
-                .layer(conv, Convolution1DLayer.builder(7)
+                .layer(conv, Convolution1D.builder(7)
                         .convolutionMode(ConvolutionMode.Same)
                         .nOut(input.size(1))
                         .weightInit(new WeightInitIdentity())
@ -23,6 +23,7 @@ package org.deeplearning4j.regressiontest;
 | 
				
			|||||||
import org.deeplearning4j.BaseDL4JTest;
 | 
					import org.deeplearning4j.BaseDL4JTest;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
					import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.distribution.*;
 | 
					import org.deeplearning4j.nn.conf.distribution.*;
 | 
				
			||||||
 | 
					import org.deeplearning4j.nn.conf.serde.CavisMapper;
 | 
				
			||||||
import org.junit.jupiter.api.Test;
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
import com.fasterxml.jackson.databind.ObjectMapper;
 | 
					import com.fasterxml.jackson.databind.ObjectMapper;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -38,7 +39,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
 | 
				
			|||||||
                        new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
 | 
					                        new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
 | 
				
			||||||
                                        new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
 | 
					                                        new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ObjectMapper om = NeuralNetConfiguration.mapper();
 | 
					        ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (Distribution d : distributions) {
 | 
					        for (Distribution d : distributions) {
 | 
				
			||||||
            String json = om.writeValueAsString(d);
 | 
					            String json = om.writeValueAsString(d);
 | 
				
			||||||
@ -50,7 +51,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testDistributionDeserializerLegacyFormat() throws Exception {
 | 
					    public void testDistributionDeserializerLegacyFormat() throws Exception {
 | 
				
			||||||
        ObjectMapper om = NeuralNetConfiguration.mapper();
 | 
					        ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String normalJson = "{\n" + "          \"normal\" : {\n" + "            \"mean\" : 0.1,\n"
 | 
					        String normalJson = "{\n" + "          \"normal\" : {\n" + "            \"mean\" : 0.1,\n"
 | 
				
			||||||
                        + "            \"std\" : 1.2\n" + "          }\n" + "        }";
 | 
					                        + "            \"std\" : 1.2\n" + "          }\n" + "        }";
 | 
				
			||||||
 | 
				
			|||||||
@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
 | 
				
			|||||||
import org.deeplearning4j.cuda.BaseCudnnHelper;
 | 
					import org.deeplearning4j.cuda.BaseCudnnHelper;
 | 
				
			||||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
 | 
					import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
 | 
				
			||||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
 | 
					import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
 | 
				
			||||||
import org.deeplearning4j.util.ConvolutionUtils;
 | 
					import org.deeplearning4j.util.Convolution2DUtils;
 | 
				
			||||||
import org.nd4j.jita.allocator.Allocator;
 | 
					import org.nd4j.jita.allocator.Allocator;
 | 
				
			||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
 | 
					import org.nd4j.jita.allocator.impl.AtomicAllocator;
 | 
				
			||||||
import org.nd4j.jita.conf.CudaEnvironment;
 | 
					import org.nd4j.jita.conf.CudaEnvironment;
 | 
				
			||||||
@ -681,9 +681,9 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        int[] outSize;
 | 
					        int[] outSize;
 | 
				
			||||||
        if (convolutionMode == ConvolutionMode.Same) {
 | 
					        if (convolutionMode == ConvolutionMode.Same) {
 | 
				
			||||||
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
 | 
					            outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
 | 
				
			||||||
            padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
 | 
					            padding = Convolution2DUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
 | 
				
			||||||
            int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
 | 
					            int[] padBottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
 | 
				
			||||||
            if(!Arrays.equals(padding, padBottomRight)){
 | 
					            if(!Arrays.equals(padding, padBottomRight)){
 | 
				
			||||||
                /*
 | 
					                /*
 | 
				
			||||||
                CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
 | 
					                CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
 | 
				
			||||||
@ -731,7 +731,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
 | 
				
			|||||||
                // CuDNN handle
 | 
					                // CuDNN handle
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
 | 
					            outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
 | 
					        return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
 | 
				
			||||||
 | 
				
			|||||||
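The manualPadBottom/manualPadRight plumbing above exists because SAME-mode padding is asymmetric whenever the total padding is odd, and cuDNN only accepts symmetric padding. A back-of-envelope sketch, assuming the usual total-padding formula (illustrative helper; which side receives the odd cell is a per-framework convention and may differ from Convolution2DUtils):

    static int[] samePad(int in, int kernel, int stride, int dilation) {
        int out = (in + stride - 1) / stride;  // ceil(in / stride)
        int total = Math.max(0, (out - 1) * stride + (kernel - 1) * dilation + 1 - in);
        int topLeft = total / 2;               // extra cell lands bottom/right in this convention
        return new int[] {topLeft, total - topLeft};
    }
    // samePad(5, 2, 1, 1) -> {0, 1}: unequal halves, the exact case the code above detects
    // by comparing getSameModeTopLeftPadding with getSameModeBottomRightPadding.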
@ -42,7 +42,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
 | 
				
			|||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
 | 
					import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
 | 
					import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
 | 
				
			||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
 | 
					import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
 | 
				
			||||||
import org.deeplearning4j.util.ConvolutionUtils;
 | 
					import org.deeplearning4j.util.Convolution2DUtils;
 | 
				
			||||||
import org.nd4j.common.primitives.Counter;
 | 
					import org.nd4j.common.primitives.Counter;
 | 
				
			||||||
import org.nd4j.common.primitives.Pair;
 | 
					import org.nd4j.common.primitives.Pair;
 | 
				
			||||||
import org.nd4j.linalg.learning.config.IUpdater;
 | 
					import org.nd4j.linalg.learning.config.IUpdater;
 | 
				
			||||||
@ -442,8 +442,8 @@ public class KerasModel {
 | 
				
			|||||||
                    KerasInput kerasInput = (KerasInput) layer;
 | 
					                    KerasInput kerasInput = (KerasInput) layer;
 | 
				
			||||||
                    LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
 | 
					                    LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
 | 
				
			||||||
                    //no dim order, try to pull it from the next layer if there is one
 | 
					                    //no dim order, try to pull it from the next layer if there is one
 | 
				
			||||||
                    if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
 | 
					                    if(Convolution2DUtils.layerHasConvolutionLayout(layer1)) {
 | 
				
			||||||
                        CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
 | 
					                        CNN2DFormat formatForLayer = Convolution2DUtils.getFormatForLayer(layer1);
 | 
				
			||||||
                        if(formatForLayer == CNN2DFormat.NCHW) {
 | 
					                        if(formatForLayer == CNN2DFormat.NCHW) {
 | 
				
			||||||
                            dimOrder = KerasLayer.DimOrder.THEANO;
 | 
					                            dimOrder = KerasLayer.DimOrder.THEANO;
 | 
				
			||||||
                        }  else if(formatForLayer == CNN2DFormat.NHWC) {
 | 
					                        }  else if(formatForLayer == CNN2DFormat.NHWC) {
 | 
				
			||||||
 | 
				
			|||||||
@ -52,28 +52,44 @@ public class KerasSequentialModel extends KerasModel {
 | 
				
			|||||||
   * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
 | 
					   * @throws UnsupportedKerasConfigurationException Unsupported Keras configuration
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public KerasSequentialModel(KerasModelBuilder modelBuilder)
 | 
					  public KerasSequentialModel(KerasModelBuilder modelBuilder)
 | 
				
			||||||
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
 | 
					      throws UnsupportedKerasConfigurationException,
 | 
				
			||||||
        this(modelBuilder.getModelJson(), modelBuilder.getModelYaml(), modelBuilder.getWeightsArchive(),
 | 
					          IOException,
 | 
				
			||||||
                modelBuilder.getWeightsRoot(), modelBuilder.getTrainingJson(), modelBuilder.getTrainingArchive(),
 | 
					          InvalidKerasConfigurationException {
 | 
				
			||||||
                modelBuilder.isEnforceTrainingConfig(), modelBuilder.getInputShape());
 | 
					    this(
 | 
				
			||||||
 | 
					        modelBuilder.getModelJson(),
 | 
				
			||||||
 | 
					        modelBuilder.getModelYaml(),
 | 
				
			||||||
 | 
					        modelBuilder.getWeightsArchive(),
 | 
				
			||||||
 | 
					        modelBuilder.getWeightsRoot(),
 | 
				
			||||||
 | 
					        modelBuilder.getTrainingJson(),
 | 
				
			||||||
 | 
					        modelBuilder.getTrainingArchive(),
 | 
				
			||||||
 | 
					        modelBuilder.isEnforceTrainingConfig(),
 | 
				
			||||||
 | 
					        modelBuilder.getInputShape());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
     * (Not recommended) Constructor for Sequential model from model configuration
 | 
					   * (Not recommended) Constructor for Sequential model from model configuration (JSON or YAML),
 | 
				
			||||||
     * (JSON or YAML), training configuration (JSON), weights, and "training mode"
 | 
					   * training configuration (JSON), weights, and "training mode" boolean indicator. When built in
 | 
				
			||||||
     * boolean indicator. When built in training mode, certain unsupported configurations
 | 
					   * training mode, certain unsupported configurations (e.g., unknown regularizers) will throw
 | 
				
			||||||
     * (e.g., unknown regularizers) will throw Exceptions. When enforceTrainingConfig=false, these
 | 
					   * Exceptions. When enforceTrainingConfig=false, these will generate warnings but will be
 | 
				
			||||||
     * will generate warnings but will be otherwise ignored.
 | 
					   * otherwise ignored.
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param modelJson model configuration JSON string
 | 
					   * @param modelJson model configuration JSON string
 | 
				
			||||||
   * @param modelYaml model configuration YAML string
 | 
					   * @param modelYaml model configuration YAML string
 | 
				
			||||||
   * @param trainingJson training configuration JSON string
 | 
					   * @param trainingJson training configuration JSON string
 | 
				
			||||||
   * @throws IOException I/O exception
 | 
					   * @throws IOException I/O exception
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
    public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
 | 
					  public KerasSequentialModel(
 | 
				
			||||||
                                String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig,
 | 
					      String modelJson,
 | 
				
			||||||
 | 
					      String modelYaml,
 | 
				
			||||||
 | 
					      Hdf5Archive weightsArchive,
 | 
				
			||||||
 | 
					      String weightsRoot,
 | 
				
			||||||
 | 
					      String trainingJson,
 | 
				
			||||||
 | 
					      Hdf5Archive trainingArchive,
 | 
				
			||||||
 | 
					      boolean enforceTrainingConfig,
 | 
				
			||||||
      int[] inputShape)
 | 
					      int[] inputShape)
 | 
				
			||||||
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
 | 
					      throws IOException,
 | 
				
			||||||
 | 
					          InvalidKerasConfigurationException,
 | 
				
			||||||
 | 
					          UnsupportedKerasConfigurationException {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
 | 
					    Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
 | 
				
			||||||
    this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
 | 
					    this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
 | 
				
			||||||
@ -83,19 +99,29 @@ public class KerasSequentialModel extends KerasModel {
 | 
				
			|||||||
    /* Determine model configuration type. */
 | 
					    /* Determine model configuration type. */
 | 
				
			||||||
    if (!modelConfig.containsKey(config.getFieldClassName()))
 | 
					    if (!modelConfig.containsKey(config.getFieldClassName()))
 | 
				
			||||||
      throw new InvalidKerasConfigurationException(
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
                    "Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
 | 
					          "Could not determine Keras model class (no "
 | 
				
			||||||
 | 
					              + config.getFieldClassName()
 | 
				
			||||||
 | 
					              + " field found)");
 | 
				
			||||||
    this.className = (String) modelConfig.get(config.getFieldClassName());
 | 
					    this.className = (String) modelConfig.get(config.getFieldClassName());
 | 
				
			||||||
    if (!this.className.equals(config.getFieldClassNameSequential()))
 | 
					    if (!this.className.equals(config.getFieldClassNameSequential()))
 | 
				
			||||||
            throw new InvalidKerasConfigurationException("Model class name must be " + config.getFieldClassNameSequential()
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
                    + " (found " + this.className + ")");
 | 
					          "Model class name must be "
 | 
				
			||||||
 | 
					              + config.getFieldClassNameSequential()
 | 
				
			||||||
 | 
					              + " (found "
 | 
				
			||||||
 | 
					              + this.className
 | 
				
			||||||
 | 
					              + ")");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* Process layer configurations. */
 | 
					    /* Process layer configurations. */
 | 
				
			||||||
    if (!modelConfig.containsKey(config.getModelFieldConfig()))
 | 
					    if (!modelConfig.containsKey(config.getModelFieldConfig()))
 | 
				
			||||||
      throw new InvalidKerasConfigurationException(
 | 
					      throw new InvalidKerasConfigurationException(
 | 
				
			||||||
                    "Could not find layer configurations (no " + config.getModelFieldConfig() + " field found)");
 | 
					          "Could not find layer configurations (no "
 | 
				
			||||||
 | 
					              + config.getModelFieldConfig()
 | 
				
			||||||
 | 
					              + " field found)");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations. For consistency
 | 
					    // Prior to Keras 2.2.3 the "config" of a Sequential model was a list of layer configurations.
 | 
				
			||||||
        // "config" is now an object containing a "name" and "layers", the latter contain the same data as before.
 | 
					    // For consistency
 | 
				
			||||||
 | 
					    // "config" is now an object containing a "name" and "layers", the latter contain the same data
 | 
				
			||||||
 | 
					    // as before.
 | 
				
			||||||
    // This change only affects Sequential models.
 | 
					    // This change only affects Sequential models.
 | 
				
			||||||
    List<Object> layerList;
 | 
					    List<Object> layerList;
 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
@@ -105,8 +131,7 @@ public class KerasSequentialModel extends KerasModel {
       layerList = (List<Object>) layerMap.get("layers");
     }
 
-        Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair =
-                prepareLayers(layerList);
+    Pair<Map<String, KerasLayer>, List<KerasLayer>> layerPair = prepareLayers(layerList);
     this.layers = layerPair.getFirst();
     this.layersOrdered = layerPair.getSecond();
 
@@ -116,15 +141,18 @@ public class KerasSequentialModel extends KerasModel {
     } else {
       /* Add placeholder input layer and update lists of input and output layers. */
       int[] firstLayerInputShape = this.layersOrdered.get(0).getInputShape();
-            Preconditions.checkState(ArrayUtil.prod(firstLayerInputShape) > 0,"Input shape must not be zero!");
+      Preconditions.checkState(
+          ArrayUtil.prod(firstLayerInputShape) > 0, "Input shape must not be zero!");
       inputLayer = new KerasInput("input1", firstLayerInputShape);
       inputLayer.setDimOrder(this.layersOrdered.get(0).getDimOrder());
       this.layers.put(inputLayer.getName(), inputLayer);
       this.layersOrdered.add(0, inputLayer);
     }
     this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getName()));
-        this.outputLayerNames = new ArrayList<>(
-                Collections.singletonList(this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
+    this.outputLayerNames =
+        new ArrayList<>(
+            Collections.singletonList(
+                this.layersOrdered.get(this.layersOrdered.size() - 1).getName()));
 
     /* Update each layer's inbound layer list to include (only) previous layer. */
     KerasLayer prevLayer = null;
@@ -136,12 +164,13 @@ public class KerasSequentialModel extends KerasModel {
 
     /* Import training configuration. */
     if (enforceTrainingConfig) {
-            if (trainingJson != null)
-                importTrainingConfiguration(trainingJson);
-            else log.warn("If enforceTrainingConfig is true, a training " +
-                    "configuration object has to be provided. Usually the only practical way to do this is to store" +
-                    " your keras model with `model.save('model_path.h5'. If you store model config and weights" +
-                    " separately no training configuration is attached.");
+      if (trainingJson != null) importTrainingConfiguration(trainingJson);
+      else
+        log.warn(
+            "If enforceTrainingConfig is true, a training "
+                + "configuration object has to be provided. Usually the only practical way to do this is to store"
+                + " your keras model with `model.save('model_path.h5'. If you store model config and weights"
+                + " separately no training configuration is attached.");
     }
 
     this.outputTypes = inferOutputTypes(inputShape);
@@ -150,9 +179,7 @@ public class KerasSequentialModel extends KerasModel {
       importWeights(weightsArchive, weightsRoot, layers, kerasMajorVersion, kerasBackend);
   }
 
-    /**
-     * Default constructor
-     */
+  /** Default constructor */
   public KerasSequentialModel() {
     super();
   }
@@ -174,14 +201,14 @@ public class KerasSequentialModel extends KerasModel {
       throw new InvalidKerasConfigurationException(
           "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
 
-        NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder();
+    NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder =
+        NeuralNetConfiguration.builder();
 
     if (optimizer != null) {
       modelBuilder.updater(optimizer);
     }
 
-        //don't forcibly override for keras import
+    // don't forcibly override for keras import
     modelBuilder.overrideNinUponBuild(false);
     /* Add layers one at a time. */
     KerasLayer prevLayer = null;
@@ -192,7 +219,10 @@ public class KerasSequentialModel extends KerasModel {
         if (nbInbound != 1)
           throw new InvalidKerasConfigurationException(
               "Layers in NeuralNetConfiguration must have exactly one inbound layer (found "
-                                    + nbInbound + " for layer " + layer.getName() + ")");
+                  + nbInbound
+                  + " for layer "
+                  + layer.getName()
+                  + ")");
         if (prevLayer != null) {
           InputType[] inputTypes = new InputType[1];
           InputPreProcessor preprocessor;
@ -200,42 +230,44 @@ public class KerasSequentialModel extends KerasModel {
 | 
				
			|||||||
            inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
 | 
					            inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0));
 | 
				
			||||||
            preprocessor = prevLayer.getInputPreprocessor(inputTypes);
 | 
					            preprocessor = prevLayer.getInputPreprocessor(inputTypes);
 | 
				
			||||||
            InputType outputType = preprocessor.getOutputType(inputTypes[0]);
 | 
					            InputType outputType = preprocessor.getOutputType(inputTypes[0]);
 | 
				
			||||||
                        layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
 | 
					            layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
 | 
				
			||||||
          } else {
 | 
					          } else {
 | 
				
			||||||
            inputTypes[0] = this.outputTypes.get(prevLayer.getName());
 | 
					            inputTypes[0] = this.outputTypes.get(prevLayer.getName());
 | 
				
			||||||
            preprocessor = layer.getInputPreprocessor(inputTypes);
 | 
					            preprocessor = layer.getInputPreprocessor(inputTypes);
 | 
				
			||||||
                        if(preprocessor != null) {
 | 
					            if (preprocessor != null) {
 | 
				
			||||||
              InputType outputType = preprocessor.getOutputType(inputTypes[0]);
 | 
					              InputType outputType = preprocessor.getOutputType(inputTypes[0]);
 | 
				
			||||||
                            layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild());
 | 
					              layer.getLayer().setNIn(outputType, modelBuilder.isOverrideNinUponBuild());
 | 
				
			||||||
 | 
					            } else layer.getLayer().setNIn(inputTypes[0], modelBuilder.isOverrideNinUponBuild());
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
                        else
 | 
					          if (preprocessor != null) {
 | 
				
			||||||
                            layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild());
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            Map<Integer, InputPreProcessor> map = new HashMap<>();
 | 
				
			||||||
 | 
					            map.put(layerIndex, preprocessor);
 | 
				
			||||||
 | 
					            modelBuilder.inputPreProcessors(map);
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
                    if (preprocessor != null)
 | 
					 | 
				
			||||||
                        modelBuilder.inputPreProcessor(layerIndex, preprocessor);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        modelBuilder.layer(layerIndex++, layer.getLayer());
 | 
					        modelBuilder.layer(layerIndex++, layer.getLayer());
 | 
				
			||||||
      } else if (layer.getVertex() != null)
 | 
					      } else if (layer.getVertex() != null)
 | 
				
			||||||
                throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name "
 | 
					        throw new InvalidKerasConfigurationException(
 | 
				
			||||||
                        + layer.getClassName() + ", layer name " + layer.getName() + ")");
 | 
					            "Cannot add vertex to NeuralNetConfiguration (class name "
 | 
				
			||||||
 | 
					                + layer.getClassName()
 | 
				
			||||||
 | 
					                + ", layer name "
 | 
				
			||||||
 | 
					                + layer.getName()
 | 
				
			||||||
 | 
					                + ")");
 | 
				
			||||||
      prevLayer = layer;
 | 
					      prevLayer = layer;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
 | 
					    /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
 | 
				
			||||||
    if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
 | 
					    if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
 | 
				
			||||||
            modelBuilder.backpropType(BackpropType.TruncatedBPTT)
 | 
					      modelBuilder
 | 
				
			||||||
 | 
					          .backpropType(BackpropType.TruncatedBPTT)
 | 
				
			||||||
          .tbpttFwdLength(truncatedBPTT)
 | 
					          .tbpttFwdLength(truncatedBPTT)
 | 
				
			||||||
          .tbpttBackLength(truncatedBPTT);
 | 
					          .tbpttBackLength(truncatedBPTT);
 | 
				
			||||||
        else
 | 
					    else modelBuilder.backpropType(BackpropType.Standard);
 | 
				
			||||||
            modelBuilder.backpropType(BackpropType.Standard);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    NeuralNetConfiguration build = modelBuilder.build();
 | 
					    NeuralNetConfiguration build = modelBuilder.build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    return build;
 | 
					    return build;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
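Note: the hunk above changes how preprocessors reach the builder. The old code registered each one directly via modelBuilder.inputPreProcessor(layerIndex, preprocessor); the new code wraps a single entry in a Map and calls inputPreProcessors(map). A minimal sketch of the new wiring (a fragment reusing the hunk's locals, not a complete import path; whether a fresh map per loop iteration replaces or merges earlier entries depends on how inputPreProcessors(...) accumulates, which is worth verifying):

    // Sketch only: mirrors the new per-layer preprocessor wiring.
    Map<Integer, InputPreProcessor> map = new HashMap<>();
    map.put(layerIndex, preprocessor);     // one entry for the current layer
    modelBuilder.inputPreProcessors(map);  // was: modelBuilder.inputPreProcessor(layerIndex, preprocessor)

    // BPTT selection is only reformatted; behavior is unchanged.
    if (useTruncatedBPTT && truncatedBPTT > 0)
      modelBuilder
          .backpropType(BackpropType.TruncatedBPTT)
          .tbpttFwdLength(truncatedBPTT)
          .tbpttBackLength(truncatedBPTT);
    else modelBuilder.backpropType(BackpropType.Standard);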
@@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -84,29 +84,29 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
         IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
                 enforceTrainingConfig, conf, kerasMajorVersion);
 
-        ConvolutionLayer.ConvolutionLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
+        var builder = Convolution1D.builder().name(this.name)
                 .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                 .activation(getIActivationFromConfig(layerConfig, conf))
                 .weightInit(init)
                 .dilation(getDilationRate(layerConfig, 1, conf, true)[0])
                 .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
                 .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
-                .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
+                .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion))
                 .hasBias(hasBias)
                 .rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
-                .stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
+                .stride(getStrideFromConfig(layerConfig, 1, conf));
         int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
         if (hasBias)
             builder.biasInit(0.0);
         if (padding != null)
-            builder.padding(padding[0]);
+            builder.padding(padding);
         if (biasConstraint != null)
             builder.constrainBias(biasConstraint);
         if (weightConstraint != null)
             builder.constrainWeights(weightConstraint);
         this.layer = builder.build();
-        Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer;
-        convolution1DLayer.setDefaultValueOverriden(true);
+        Convolution1D convolution1D = (Convolution1D) layer;
+        convolution1D.setDefaultValueOverriden(true);
     }
 
     /**
@@ -114,8 +114,8 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
      *
      * @return ConvolutionLayer
      */
-    public Convolution1DLayer getAtrousConvolution1D() {
-        return (Convolution1DLayer) this.layer;
+    public Convolution1D getAtrousConvolution1D() {
+        return (Convolution1D) this.layer;
     }
 
     /**
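Note: besides the Convolution1DLayer -> Convolution1D rename, the builder above now takes the int[] returned by the config helpers directly; kernelSize, stride and padding previously had element [0] extracted. A usage sketch under that assumption (layer name and channel count are invented):

    // Sketch: array-valued setters as in the hunk above.
    int[] kernel = {3};
    int[] stride = {1};
    var builder = Convolution1D.builder()
        .name("conv1d_example")  // hypothetical
        .nOut(64)                // hypothetical
        .kernelSize(kernel)      // was .kernelSize(kernel[0])
        .stride(stride);         // was .stride(stride[0])
    Convolution1D layer = (Convolution1D) builder.build();
    layer.setDefaultValueOverriden(true); // as done at the end of the constructor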
@@ -24,6 +24,7 @@ import lombok.val;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -85,7 +86,7 @@ public class KerasAtrousConvolution2D extends KerasConvolution {
         IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
                 enforceTrainingConfig, conf, kerasMajorVersion);
 
-        val builder = ConvolutionLayer.builder().name(this.name)
+        val builder = Convolution2D.builder().name(this.name)
                 .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                 .activation(getIActivationFromConfig(layerConfig, conf))
                 .weightInit(init)
@@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -93,7 +93,7 @@ public class KerasConvolution1D extends KerasConvolution {
 
         IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
                 enforceTrainingConfig, conf, kerasMajorVersion);
-        Convolution1DLayer.Convolution1DLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
+        var builder = Convolution1D.builder().name(this.name)
                 .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                 .activation(getIActivationFromConfig(layerConfig, conf))
                 .weightInit(init)
@@ -125,9 +125,9 @@ public class KerasConvolution1D extends KerasConvolution {
 
         this.layer = builder.build();
         //set this in order to infer the dimensional format
-        Convolution1DLayer convolution1DLayer = (Convolution1DLayer) this.layer;
-        convolution1DLayer.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
-        convolution1DLayer.setDefaultValueOverriden(true);
+        Convolution1D convolution1D = (Convolution1D) this.layer;
+        convolution1D.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
+        convolution1D.setDefaultValueOverriden(true);
     }
 
     /**
@@ -135,8 +135,8 @@ public class KerasConvolution1D extends KerasConvolution {
     *
     * @return  ConvolutionLayer
     */
-    public Convolution1DLayer getConvolution1DLayer() {
-        return (Convolution1DLayer) this.layer;
+    public Convolution1D getConvolution1DLayer() {
+        return (Convolution1D) this.layer;
     }
 
 
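Note: the cast kept at the end of KerasConvolution1D's constructor is what lets the importer record the dimensional format on the built layer: Keras models saved with the TensorFlow dim ordering are channels-last (NHWC), everything else is treated as channels-first (NCHW). Equivalent logic, using only the setters visible in the hunk:

    // Sketch: the Keras dim ordering decides the recorded data format.
    Convolution1D c = (Convolution1D) layer;
    c.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
    c.setDefaultValueOverriden(true); // keeps imported values from being reset to defaults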
@@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -95,7 +96,7 @@ public class KerasConvolution2D extends KerasConvolution {
         LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                 layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
 
-        final var builder = ConvolutionLayer.builder().name(this.name)
+        final var builder = Convolution2D.builder().name(this.name)
                 .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                 .activation(getIActivationFromConfig(layerConfig, conf))
                 .weightInit(init)
@@ -23,6 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
@@ -41,8 +42,8 @@ public class JsonTest extends BaseDL4JTest {
 
         };
         for(InputPreProcessor p : pp ){
-            String s = NeuralNetConfiguration.mapper().writeValueAsString(p);
-            InputPreProcessor p2 = NeuralNetConfiguration.mapper().readValue(s, InputPreProcessor.class);
+            String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(p);
+            InputPreProcessor p2 = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, InputPreProcessor.class);
             assertEquals(p, p2);
         }
 
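Note: JsonTest now round-trips preprocessors through the shared CavisMapper instead of NeuralNetConfiguration.mapper(). The writeValueAsString/readValue calls above imply CavisMapper.getMapper(CavisMapper.Type.JSON) returns a configured Jackson ObjectMapper; under that assumption, and with hypothetical constructor arguments, the pattern is:

    // Sketch: JSON round-trip through the central mapper.
    ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
    InputPreProcessor original = new PermutePreprocessor(1, 0, 2); // hypothetical permutation
    String json = mapper.writeValueAsString(original);
    InputPreProcessor restored = mapper.readValue(json, InputPreProcessor.class);
    assertEquals(original, restored);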
@@ -29,11 +29,8 @@ import org.deeplearning4j.gradientcheck.GradientCheckUtil;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.layers.IOutputLayer;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
-import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
-import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.layers.LossLayer;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
@@ -656,7 +653,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
             MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
                     true, true, false, null, null);
             Layer l = net.getLayer(0);
-            Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig();
+            Convolution1D c1d = (Convolution1D) l.getTrainingConfig();
             assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
         }
     }
@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
 
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.dropout.Dropout;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@@ -97,7 +97,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
         config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
         layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
 
-        Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
+        Convolution1D layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
         assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
         assertEquals(LAYER_NAME, layer.getName());
         assertEquals(INIT_DL4J, layer.getWeightInit());
@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
 
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.dropout.Dropout;
-import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
@@ -119,7 +119,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
         config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
         layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
 
-        Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
+        Convolution1D layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
         assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
         assertEquals(LAYER_NAME, layer.getName());
         assertEquals(INIT_DL4J, layer.getWeightInit());
@@ -22,8 +22,6 @@
 package net.brutex.ai.dnn.api;
 
 import java.io.Serializable;
-import java.util.List;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 
 public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
 
@@ -31,9 +31,11 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
 public class NN {
 
 
-  public static NeuralNetConfigurationBuilder<?, ?> net() {
+  public static NeuralNetConfigurationBuilder<?, ?> nn() {
     return NeuralNetConfiguration.builder();
   }
 
+  public static DenseLayer.DenseLayerBuilder<?,?> dense() { return DenseLayer.builder(); }
+
 
 }
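Note: with net() renamed to nn() and the new dense() shortcut, NN is a small fluent entry point. A hypothetical usage sketch (layer sizes invented; .layer(int, ...) is the same builder method KerasSequentialModel uses above):

    // Sketch: two dense layers configured via the NN shortcuts.
    NeuralNetConfiguration conf = NN.nn()
        .layer(0, NN.dense().nIn(784).nOut(128).build())
        .layer(1, NN.dense().nIn(128).nOut(10).build())
        .build();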
@@ -23,7 +23,6 @@ package net.brutex.ai.dnn.networks;
 
 import java.io.Serializable;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.Map;
 import lombok.Getter;
 import lombok.NonNull;
@@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-
 /**
  * Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
  * transforms this input data by calculating a weighted sum over the inputs and (3) applies a
@@ -20,6 +20,10 @@
 
 package org.deeplearning4j.earlystopping;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import lombok.Data;
 import lombok.NoArgsConstructor;
 import net.brutex.ai.dnn.api.IModel;
@@ -30,11 +34,6 @@ import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
 import org.deeplearning4j.exception.DL4JInvalidConfigException;
 import org.nd4j.common.function.Supplier;
 
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-
 @Data
 @NoArgsConstructor
 public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
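Note: the remaining early-stopping hunks are almost entirely import reshuffles, so for orientation only: the classes they touch compose into one configuration object. A sketch under the usual DL4J early-stopping API; the Builder methods and MaxEpochsTerminationCondition do not appear in this diff, so treat them as assumptions:

    // Sketch: how the saver / termination / score-calculator classes below fit together.
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
        new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(30))
            .iterationTerminationConditions(
                new MaxTimeIterationTerminationCondition(6, TimeUnit.HOURS))
            .scoreCalculator(new DataSetLossCalculator(validationIter, true)) // validationIter: a DataSetIterator
            .modelSaver(new LocalFileModelSaver("/tmp/esModels"))
            .build();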
@@ -20,16 +20,15 @@
 
 package org.deeplearning4j.earlystopping;
 
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonSubTypes;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import java.io.IOException;
+import java.io.Serializable;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
 import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
 import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonSubTypes;
-import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
-import java.io.IOException;
-import java.io.Serializable;
 
 @JsonInclude(JsonInclude.Include.NON_NULL)
 @JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
@@ -20,11 +20,10 @@
 
 package org.deeplearning4j.earlystopping;
 
-import lombok.Data;
-import net.brutex.ai.dnn.api.IModel;
-
 import java.io.Serializable;
 import java.util.Map;
+import lombok.Data;
+import net.brutex.ai.dnn.api.IModel;
 
 @Data
 public class EarlyStoppingResult<T extends IModel> implements Serializable {
@@ -20,10 +20,9 @@
 
 package org.deeplearning4j.earlystopping.saver;
 
-import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
-import net.brutex.ai.dnn.api.IModel;
-
 import java.io.IOException;
+import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
 
 public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
 
@@ -20,15 +20,14 @@
 
 package org.deeplearning4j.earlystopping.saver;
 
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
 import org.apache.commons.io.FilenameUtils;
 import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.util.ModelSerializer;
 
-import java.io.File;
-import java.io.IOException;
-import java.nio.charset.Charset;
-
 public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
 
     private static final String BEST_GRAPH_BIN = "bestGraph.bin";
@@ -20,15 +20,14 @@
 
 package org.deeplearning4j.earlystopping.saver;
 
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
 import org.apache.commons.io.FilenameUtils;
 import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.util.ModelSerializer;
 
-import java.io.File;
-import java.io.IOException;
-import java.nio.charset.Charset;
-
 public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
 
     private static final String BEST_MODEL_BIN = "bestModel.bin";
@@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
 import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 
 public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
 
@@ -20,8 +20,9 @@
 
 package org.deeplearning4j.earlystopping.scorecalc;
 
-import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -29,7 +30,6 @@ import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
 
@@ -20,6 +20,8 @@
 
 package org.deeplearning4j.earlystopping.scorecalc;
 
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.NoArgsConstructor;
 import lombok.val;
 import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -27,8 +29,6 @@ import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
-import com.fasterxml.jackson.annotation.JsonIgnore;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 @NoArgsConstructor
 @Deprecated
@@ -20,12 +20,11 @@
 
 package org.deeplearning4j.earlystopping.scorecalc;
 
-import net.brutex.ai.dnn.api.IModel;
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
 import java.io.Serializable;
+import net.brutex.ai.dnn.api.IModel;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 @JsonInclude(JsonInclude.Include.NON_NULL)
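Note: for context on the annotations kept on this interface: @JsonTypeInfo(Id.CLASS, property = "@class") makes Jackson embed the implementing class name when a calculator is serialized, so it can be restored through the interface type. Minimal sketch with a plain ObjectMapper (the DataSetLossCalculator constructor arguments are an assumption):

    // Sketch: polymorphic round-trip enabled by @JsonTypeInfo on the interface.
    ObjectMapper mapper = new ObjectMapper();
    ScoreCalculator calc = new DataSetLossCalculator(validationIter, true);
    String json = mapper.writeValueAsString(calc);  // contains "@class": "...DataSetLossCalculator"
    ScoreCalculator restored = mapper.readValue(json, ScoreCalculator.class);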
@@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
 import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 
 public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
 
@@ -20,9 +20,9 @@
 
 package org.deeplearning4j.earlystopping.scorecalc.base;
 
+import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
 import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
-import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.evaluation.IEvaluation;
@@ -21,8 +21,8 @@
 package org.deeplearning4j.earlystopping.scorecalc.base;
 
 import lombok.NonNull;
-import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
 import net.brutex.ai.dnn.api.IModel;
+import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -20,8 +20,8 @@
 
 package org.deeplearning4j.earlystopping.termination;
 
-import lombok.Data;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
 
 @Data
 public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
@@ -22,9 +22,7 @@ package org.deeplearning4j.earlystopping.termination;
 
 
 import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
 import java.io.Serializable;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@@ -22,7 +22,6 @@ package org.deeplearning4j.earlystopping.termination;
 
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
 import java.io.Serializable;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@@ -20,10 +20,10 @@
 
 package org.deeplearning4j.earlystopping.termination;
 
-import lombok.Data;
-import lombok.NoArgsConstructor;
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+import lombok.NoArgsConstructor;
 
 @NoArgsConstructor
 @Data
@@ -20,8 +20,8 @@
 
 package org.deeplearning4j.earlystopping.termination;
 
-import lombok.Data;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
 
 @Data
 public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
@@ -20,10 +20,9 @@
 
 package org.deeplearning4j.earlystopping.termination;
 
-import lombok.Data;
 import com.fasterxml.jackson.annotation.JsonProperty;
-
 import java.util.concurrent.TimeUnit;
+import lombok.Data;
 
 /**Terminate training based on max time.
  */
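Note: the TimeUnit import this file keeps is the unit the condition is parameterized with; typical usage is a one-liner (constructor shape as commonly used in DL4J, stated as an assumption):

    // Sketch: stop training once a wall-clock budget is exhausted.
    MaxTimeIterationTerminationCondition cond =
        new MaxTimeIterationTerminationCondition(6, TimeUnit.HOURS);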
@@ -20,9 +20,9 @@
 
 package org.deeplearning4j.earlystopping.termination;
 
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 @Slf4j
 @Data
@@ -20,6 +20,12 @@
 
 package org.deeplearning4j.earlystopping.trainer;
 
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
 import org.deeplearning4j.earlystopping.EarlyStoppingResult;
@@ -40,13 +46,6 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.Map;
-
 public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
 
     private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
@@ -20,7 +20,6 @@
 
 package org.deeplearning4j.earlystopping.trainer;
 
-import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
 import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
 import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
@@ -20,6 +20,13 @@
 
 package org.deeplearning4j.eval;
 
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
 import org.nd4j.common.primitives.AtomicBoolean;
@@ -28,14 +35,6 @@ import org.nd4j.common.primitives.serde.JsonDeserializerAtomicBoolean;
 import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
 import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
 import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
-import com.fasterxml.jackson.annotation.JsonAutoDetect;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationFeature;
-import com.fasterxml.jackson.databind.MapperFeature;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.SerializationFeature;
-import com.fasterxml.jackson.databind.module.SimpleModule;
-import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 
 @Deprecated
 @EqualsAndHashCode(callSuper = false)
@@ -20,15 +20,8 @@
 
 package org.deeplearning4j.eval;
 
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-import lombok.Getter;
 
-import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 
 @Deprecated
 public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
@@ -20,14 +20,11 @@
 
 package org.deeplearning4j.eval;
 
-import lombok.EqualsAndHashCode;
-import lombok.NonNull;
-import org.nd4j.evaluation.EvaluationAveraging;
-import org.nd4j.evaluation.IEvaluation;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
 import java.util.List;
 import java.util.Map;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 @EqualsAndHashCode(callSuper = true)
 @Deprecated
@@ -20,9 +20,9 @@
 
 package org.deeplearning4j.eval;
 
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 @Deprecated
 @Getter
@@ -20,11 +20,10 @@
 
 package org.deeplearning4j.eval;
 
+import java.util.List;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 
-import java.util.List;
-
 @Deprecated
 @Data
 @EqualsAndHashCode(callSuper = true)
@@ -20,10 +20,10 @@
 
 package org.deeplearning4j.eval.curves;
 
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import org.nd4j.evaluation.curves.BaseHistogram;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 @Deprecated
 @Data
@@ -20,13 +20,9 @@
 
 package org.deeplearning4j.eval.curves;
 
-import com.google.common.base.Preconditions;
-import lombok.AllArgsConstructor;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import com.fasterxml.jackson.annotation.JsonProperty;
-
-import java.util.Arrays;
 
 @Deprecated
 @Data
@@ -20,8 +20,8 @@
 
 package org.deeplearning4j.eval.curves;
 
-import lombok.NonNull;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.NonNull;
 
 @Deprecated
 public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
@@ -20,10 +20,9 @@
 
 package org.deeplearning4j.eval.curves;
 
-import com.google.common.base.Preconditions;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import com.fasterxml.jackson.annotation.JsonProperty;
 
 @Deprecated
 @Data
@@ -20,7 +20,6 @@
 
 package org.deeplearning4j.eval.meta;
 
-import lombok.AllArgsConstructor;
 import lombok.Data;
 
 @Data
@@ -20,6 +20,7 @@
 
 package org.deeplearning4j.nn.adapters;
 
+import java.util.List;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.NoArgsConstructor;
@@ -32,8 +33,6 @@ import org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 
-import java.util.List;
-
 @Builder
 @AllArgsConstructor
 @NoArgsConstructor
@@ -21,7 +21,6 @@
 
 package org.deeplearning4j.nn.api;
 
-import lombok.Getter;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 
@@ -20,14 +20,12 @@
 
 package org.deeplearning4j.nn.api;
 
+import java.util.List;
 import net.brutex.ai.dnn.api.IModel;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 
-import java.util.List;
-
-
 public interface Classifier extends IModel {
 
 
@@ -20,13 +20,12 @@
 
 package org.deeplearning4j.nn.api;
 
+import java.util.List;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
 
-import java.util.List;
-
 public interface ITraininableLayerConfiguration {
 
     /**
@@ -21,7 +21,7 @@
 package org.deeplearning4j.nn.api;
 
 
-import java.util.Map;
+import java.io.Serializable;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -29,10 +29,8 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.LayerHelper;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.common.primitives.Pair;
-
-import java.io.Serializable;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * A layer is the highest-level building block in deep learning. A layer is a container that usually
@@ -20,13 +20,12 @@
 
 package org.deeplearning4j.nn.api;
 
+import java.util.List;
+import java.util.Map;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-import java.util.List;
-import java.util.Map;
-
 /**
  * Param initializer for a layer
  *
@@ -20,11 +20,10 @@
 
 package org.deeplearning4j.nn.api;
 
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-
 import java.io.Serializable;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * Update the model
@@ -22,8 +22,8 @@ package org.deeplearning4j.nn.api.layers;
 
 import org.deeplearning4j.nn.api.Classifier;
 import org.deeplearning4j.nn.api.Layer;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 public interface IOutputLayer extends Layer, Classifier {
 
@@ -20,11 +20,10 @@
 
 package org.deeplearning4j.nn.api.layers;
 
-import org.deeplearning4j.nn.api.Layer;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
 import java.io.Serializable;
 import java.util.Set;
+import org.deeplearning4j.nn.api.Layer;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface LayerConstraint extends Cloneable, Serializable {
@@ -20,13 +20,12 @@
 
 package org.deeplearning4j.nn.api.layers;
 
+import java.util.Map;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.gradient.Gradient;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.primitives.Pair;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-
-import java.util.Map;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 public interface RecurrentLayer extends Layer {
 
@@ -20,6 +20,12 @@
 
 package org.deeplearning4j.nn.conf;
 
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
 import lombok.*;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
@@ -34,6 +40,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
@@ -42,16 +49,9 @@ import org.nd4j.common.base.Preconditions;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.buffer.DataType;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.*;
-
 @Data
 @EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
 @AllArgsConstructor(access = AccessLevel.PRIVATE)
@@ -110,7 +110,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      * @return YAML representation of configuration
      */
     public String toYaml() {
-        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
         synchronized (mapper) {
             try {
                 return mapper.writeValueAsString(this);
@@ -127,7 +127,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      * @return {@link ComputationGraphConfiguration}
      */
     public static ComputationGraphConfiguration fromYaml(String json) {
-        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
         try {
             return mapper.readValue(json, ComputationGraphConfiguration.class);
         } catch (IOException e) {
@@ -140,7 +140,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      */
     public String toJson() {
         //As per NeuralNetConfiguration.toJson()
-        ObjectMapper mapper = NeuralNetConfiguration.mapper();
+        ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
         synchronized (mapper) {
             //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
             //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
@@ -160,7 +160,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
      */
     public static ComputationGraphConfiguration fromJson(String json) {
         //As per NeuralNetConfiguration.fromJson()
-        ObjectMapper mapper = NeuralNetConfiguration.mapper();
+        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
         ComputationGraphConfiguration conf;
         try {
             conf = mapper.readValue(json, ComputationGraphConfiguration.class);
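All four entry points above now obtain their mapper from the shared CavisMapper factory instead of the static NeuralNetConfiguration mappers. A minimal round-trip sketch under that assumption (CavisMapper.getMapper and its Type enum are exactly what this hunk shows; the rest is plain Jackson):

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
    import org.deeplearning4j.nn.conf.serde.CavisMapper;

    static ComputationGraphConfiguration roundTrip(ComputationGraphConfiguration conf) throws java.io.IOException {
        // Same mapper family the class itself now uses for toJson()/fromJson().
        ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
        String json = mapper.writeValueAsString(conf);
        return mapper.readValue(json, ComputationGraphConfiguration.class);
    }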
@@ -19,10 +19,10 @@
  */
 package org.deeplearning4j.nn.conf;
 
-import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
-import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
+import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
 
 @JsonSerialize(using = DataFormatSerializer.class)
 @JsonDeserialize(using = DataFormatDeserializer.class)
@@ -21,14 +21,13 @@
 package org.deeplearning4j.nn.conf;
 
 
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import java.io.Serializable;
 import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.primitives.Pair;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import com.fasterxml.jackson.annotation.JsonTypeInfo;
-
-import java.io.Serializable;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface InputPreProcessor extends Serializable, Cloneable {
@@ -21,10 +21,9 @@
 package org.deeplearning4j.nn.conf;
 
 import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ArrayNode;
+import java.util.*;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import lombok.extern.slf4j.Slf4j;
@@ -35,10 +34,8 @@ import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.dropout.Dropout;
 import org.deeplearning4j.nn.conf.dropout.IDropout;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
-import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
 import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
 import org.deeplearning4j.nn.weights.IWeightInit;
@@ -47,7 +44,6 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.deeplearning4j.util.NetworkUtils;
 import org.nd4j.common.base.Preconditions;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.learning.config.IUpdater;
@@ -57,9 +53,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.learning.regularization.WeightDecay;
 
-import java.io.IOException;
-import java.util.*;
-
 /**
  * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
 * multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers
@@ -159,7 +152,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
   @Getter @Setter @NonNull @lombok.Builder.Default
   protected BackpropType backpropType = BackpropType.Standard;
 
-  @Getter @lombok.Builder.Default
+  @Getter @Setter @Singular
   protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
   /**
    * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated)
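Replacing @lombok.Builder.Default with @Singular on inputPreProcessors changes the generated builder surface: Lombok emits a singular adder, a bulk adder, and a clear method, which is why the hand-written inputPreProcessor(...) builder method is commented out in a later hunk. A sketch of the expected generated API, assuming standard @Singular semantics for a Map field (the preprocessor instance is just an example value):

    // Generated by Lombok for @Singular Map<Integer, InputPreProcessor> inputPreProcessors:
    builder.inputPreProcessor(0, new CnnToFeedForwardPreProcessor())  // put a single entry
           .inputPreProcessors(existingPreProcessorMap)               // put all entries of a map
           .clearInputPreProcessors();                                // remove all entries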
@@ -331,7 +324,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
    */
   @Getter @Setter @lombok.Builder.Default private IUpdater biasUpdater = null;
 
-
   /**
    * Weight initialization scheme to use, for initial weight values Note: values set by this method
    * will be applied to all applicable layers in the network, unless a different value is explicitly
@@ -339,6 +331,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
    * and can be overridden on a per-layer basis.
    */
   @Getter @Setter @lombok.Builder.Default private IWeightInit weightInit = new WeightInitXavier();
+
   /**
    * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. See
    * {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE<br>
@@ -416,113 +409,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
   @Getter @Setter @lombok.Builder.Default private double biasInit = 0.0;
   @Getter @Setter @lombok.Builder.Default private double gainInit = 1.0;
 
-  /**
-   * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
-   * from handling of {@link Activation} above.
-   *
-   * @return True if all is well and layer iteration shall continue. False else-wise.
-   */
-  private static boolean handleLegacyWeightInitFromJson(
-      String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) {
-    if ((l instanceof BaseLayerConfiguration)
-        && ((BaseLayerConfiguration) l).getWeightInit() == null) {
-      try {
-        JsonNode jsonNode = mapper.readTree(json);
-        if (confs == null) {
-          confs = jsonNode.get("confs");
-        }
-        if (confs instanceof ArrayNode) {
-          ArrayNode layerConfs = (ArrayNode) confs;
-          JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
-          if (outputLayerNNCNode == null) {
-            return false; // Should never happen...
-          }
-          JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
-
-          if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
-            return true;
-          }
-
-          JsonNode layerNode = layerWrapperNode.elements().next();
-          JsonNode weightInit =
-              layerNode.get("weightInit"); // Should only have 1 element: "dense", "output", etc
-          JsonNode distribution = layerNode.get("dist");
-
-          Distribution dist = null;
-          if (distribution != null) {
-            dist = mapper.treeToValue(distribution, Distribution.class);
-          }
-
-          if (weightInit != null) {
-            final IWeightInit wi =
-                WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
-            ((BaseLayerConfiguration) l).setWeightInit(wi);
-          }
-        }
-
-      } catch (IOException e) {
-        log.warn(
-            "ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
-            e);
-      }
-    }
-    return true;
-  }
-
-  /**
-   * Object mapper for serialization of configurations
-   *
-   * @return
-   */
-  public static ObjectMapper mapperYaml() {
-    return JsonMappers.getMapperYaml();
-  }
-
-  /**
-   * Object mapper for serialization of configurations
-   *
-   * @return
-   */
-  public static ObjectMapper mapper() {
-    return JsonMappers.getMapper();
-  }
-
-  public static NeuralNetBaseBuilderConfiguration fromYaml(String input) {
-    throw new RuntimeException("Needs fixing - not supported."); // TODO
-  }
-
-  /**
-   * @return JSON representation of NN configuration
-   */
-  public String toYaml() {
-    ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapperYaml();
-    synchronized (mapper) {
-      try {
-        return mapper.writeValueAsString(this);
-      } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
-        throw new RuntimeException(e);
-      }
-    }
-  }
-
-  /**
-   * @return JSON representation of NN configuration
-   */
-  public String toJson() {
-    ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapper();
-    synchronized (mapper) {
-      // JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
-      // occasionally
-      // when writeValueAsString is used by multiple threads. This results in invalid JSON. See
-      // issue #3243
-      try {
-        return mapper.writeValueAsString(this);
-      } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
-        throw new RuntimeException(e);
-      }
-    }
-  }
-
   @Override
   public NeuralNetBaseBuilderConfiguration clone() {
     NeuralNetBaseBuilderConfiguration clone;
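This hunk removes the legacy weight-init JSON fix-up together with the class's own mapper()/mapperYaml()/toYaml()/toJson() plumbing. Presumably callers migrate to the shared CavisMapper introduced elsewhere in this branch; a hedged sketch of the replacement calls (only getMapper and its Type enum are visible in this diff):

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.deeplearning4j.nn.conf.serde.CavisMapper;

    // Instead of NeuralNetBaseBuilderConfiguration.mapper() / mapperYaml():
    ObjectMapper jsonMapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
    ObjectMapper yamlMapper = CavisMapper.getMapper(CavisMapper.Type.YAML);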
@@ -561,14 +447,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
 
     List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
 
-    public B activation(Activation activation) {
-      this.activation = activation;
-      return self();
-    }
-    public B activation(IActivation activation) {
-      this.activation = activation;
-      return self();
-    }
+
     /**
      * Set constraints to be applied to all layers. Default: no constraints.<br>
      * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
@@ -583,7 +462,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B constrainWeights(LayerConstraint... constraints) {
       constrainWeights$value = Arrays.asList(constraints);
       constrainWeights$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
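Here and in the hunks that follow, the unchecked "return (B) this;" casts are replaced with Lombok's SuperBuilder self() hook, which returns the concrete builder type without a cast or unchecked-cast warning. A self-contained sketch of the pattern with hypothetical Base/Child classes (not from this repository):

    import lombok.experimental.SuperBuilder;

    @SuperBuilder
    class Base {
        int halfValue;

        // Pre-declared builder body: Lombok fills in the rest, including self().
        public static abstract class BaseBuilder<C extends Base, B extends BaseBuilder<C, B>> {
            public B value(int v) {
                this.halfValue = v / 2;
                return self();  // stays typed as the most-derived builder, no (B) cast
            }
        }
    }

    @SuperBuilder
    class Child extends Base {
        String name;
    }

    // Chained calls keep Child's builder type thanks to self():
    // Child c = Child.builder().value(10).name("x").build();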
@@ -618,7 +497,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B constrainAllParameters(LayerConstraint... constraints) {
       allParamConstraints$value = Arrays.asList(constraints);
       allParamConstraints$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -635,7 +514,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B constrainBias(LayerConstraint... constraints) {
       biasConstraints$value = Arrays.asList(constraints);
       biasConstraints$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -645,11 +524,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @param processor what to use to preProcess the data.
      * @return builder pattern
      */
-    public B inputPreProcessor(Integer layer, InputPreProcessor processor) {
-      inputPreProcessors$value.put(layer, processor);
-      inputPreProcessors$set = true;
-      return (B) this;
-    }
+   //public B inputPreProcessor(@NonNull Integer layer, @NonNull InputPreProcessor processor) {
+   //   inputPreProcessors$value.put(layer, processor);
+   //   inputPreProcessors$set = true;
+   //   return self();
+   // }
 
     /**
      * Set layer at index
@@ -658,7 +537,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @param layer the layer
      * @return builder
      */
-    public B layer(Integer index, @NonNull LayerConfiguration layer) {
+    public B layer(@NonNull Integer index, @NonNull LayerConfiguration layer) {
       innerConfigurations$value.add(index, layer);
       innerConfigurations$set = true;
       return self();
@@ -680,10 +559,11 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @param layer the layer
      * @return builder
      */
+    @JsonIgnore
     public B layer(@NonNull LayerConfiguration layer) {
       innerConfigurations$value.add(layer);
       innerConfigurations$set = true;
-      return (B) this;
+      return self();
     }
     public B layer(@NonNull LayerConfiguration.LayerConfigurationBuilder<?, ?> layer) {
       return this.layer(layer.build());
@@ -699,7 +579,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B layersFromArray(@NonNull LayerConfiguration[] arrLayers) {
       innerConfigurations$value.addAll(List.of(arrLayers));
       innerConfigurations$set = true;
-      return (B) this;
+      return self();
     }
 
     /** Specify additional layer configurations */
@@ -707,7 +587,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B layersFromList(@NonNull List<LayerConfiguration> listLayers) {
       innerConfigurations$value.addAll(listLayers);
       innerConfigurations$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -723,7 +603,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
         regularization$value.add(new L1Regularization(l1));
       }
       regularization$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -751,7 +631,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
         regularization$value.add(new L2Regularization(l2));
       }
       regularization$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -766,7 +646,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
         regularizationBias$value.add(new L1Regularization(l1Bias));
       }
       regularizationBias$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -791,7 +671,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
             "L2 bias regularization removed: incompatible with added WeightDecay regularization");
         regularizationBias$value.add(new L2Regularization(l2Bias));
       }
-      return (B) this;
+      return self();
     }
 
     /**
@@ -833,7 +713,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
         regularization$value.add(new WeightDecay(coefficient, applyLR));
       }
       regularization$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
@@ -870,7 +750,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
         regularizationBias$value.add(new WeightDecay(coefficient, applyLR));
       }
       regularization$set = true;
-      return (B) this;
+      return self();
     }
 
 
@@ -881,7 +761,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      */
     @Deprecated
     public B list() {
-      return (B) this;
+      return self();
     }
 
     /**
@@ -893,23 +773,24 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      *
      * @param distribution Distribution to use for weight initialization
      */
-    @JsonIgnore
+    @JsonIgnore @Deprecated
     public B weightInit(Distribution distribution) {
       this.weightInit$value = new WeightInitDistribution(distribution);
       this.weightInit$set = true;
-      return (B) this;
+      return self();
     }
     @JsonIgnore
     public B weightInit(WeightInit weightInit) {
       this.weightInit$value = weightInit.getWeightInitFunction();
       this.weightInit$set = true;
-      return (B) this;
+      return self();
     }
 
+    @JsonProperty("weightInit") //this is needed for Jackson < 2.4, otherwise JsonIgnore on the other setters will ignore this also
     public B weightInit(IWeightInit iWeightInit) {
       this.weightInit$value = iWeightInit;
       this.weightInit$set = true;
-      return (B) this;
+      return self();
     }
 
     /**
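The `@JsonProperty("weightInit")` added above addresses a real Jackson pitfall: when several overloads feed the same logical property and some carry `@JsonIgnore`, Jackson versions before 2.4 let the ignore spill over and suppress the property entirely. A minimal sketch of the pattern with a hypothetical bean (illustrative, not code from this repository):

```java
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

public class WeightInitBean {
  private String weightInit;

  @JsonIgnore // convenience overload, not meant for JSON binding
  public void setWeightInit(char[] raw) {
    this.weightInit = new String(raw);
  }

  // Without the explicit @JsonProperty, the @JsonIgnore on the overload above
  // would (on Jackson < 2.4) suppress the whole "weightInit" property.
  @JsonProperty("weightInit")
  public void setWeightInit(String value) {
    this.weightInit = value;
  }

  public String getWeightInit() {
    return weightInit;
  }

  public static void main(String[] args) throws Exception {
    WeightInitBean b =
        new ObjectMapper().readValue("{\"weightInit\":\"XAVIER\"}", WeightInitBean.class);
    System.out.println(b.getWeightInit()); // XAVIER
  }
}
```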
@@ -918,12 +799,13 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @param distribution
      * @return
      */
+    @JsonIgnore
     public B dist(@NonNull Distribution distribution) {
-      return (B) weightInit(distribution);
+      return weightInit(distribution);
     }
 
     public B dropOut(@NonNull IDropout dropout) {
-      return (B) idropOut(dropout);
+      return idropOut(dropout);
     }
 
     /**
@@ -933,7 +815,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
      * @return builder
      */
     public B dropOut(double dropout) {
-      return (B) idropOut(new Dropout(dropout));
+      return idropOut(new Dropout(dropout));
     }
 
     /**
@@ -946,7 +828,8 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B confs(@NonNull List<NeuralNetConfiguration> confs) {
       innerConfigurations$value.addAll(confs);
       innerConfigurations$set = true;
-      return (B) this;
+      return self();
     }
+
   }
 }
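The change repeated throughout this builder -- `return (B) this;` becoming `return self();` -- switches to the self-referencing generic that Lombok's `@SuperBuilder` generates, avoiding an unchecked cast in every fluent setter. A hand-rolled sketch of the idea (illustrative names, not Lombok's generated code):

```java
// Each builder level declares B as "the concrete builder type extending me",
// and self() is overridden at the bottom of the hierarchy to return this.
public abstract class BaseBuilder<C, B extends BaseBuilder<C, B>> {
  protected double l1;

  protected abstract B self(); // concrete subclasses return themselves

  public B l1(double l1) {
    this.l1 = l1;
    return self(); // type-safe; no unchecked (B) this cast needed
  }

  public abstract C build();
}

final class NetBuilder extends BaseBuilder<String, NetBuilder> {
  @Override
  protected NetBuilder self() {
    return this;
  }

  @Override
  public String build() {
    return "l1=" + l1;
  }
}
```

Chained calls keep the most-derived builder type this way, which is what the scattered `(B) this` casts were approximating without compile-time safety.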
@@ -22,39 +22,27 @@ package org.deeplearning4j.nn.conf;
 
 import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
-import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.*;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import lombok.extern.jackson.Jacksonized;
 import lombok.extern.slf4j.Slf4j;
 import net.brutex.ai.dnn.api.IModel;
-import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
-import org.deeplearning4j.nn.conf.serde.JsonMappers;
-import org.deeplearning4j.nn.weights.IWeightInit;
-import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.nn.conf.serde.CavisMapper;
 import org.deeplearning4j.util.OutputLayerUtil;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.config.Sgd;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
-import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
-import org.nd4j.linalg.lossfunctions.impl.LossMSE;
-import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
-
-import java.io.IOException;
-import java.util.*;
-import java.util.stream.Collectors;
 
 /**
  * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
@@ -62,71 +50,50 @@ import java.util.stream.Collectors;
  * and their hyperparameters. Hyperparameters are variables that determine how a neural network
  * learns. They include how many times to update the weights of the model, how to initialize those
  * weights, which activation function to attach to the nodes, which optimization algorithm to use,
- * and how fast the model should learn. This is what one configuration would look like:
- * <br/><br/>
- *
- * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br/>
- * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br/>
- * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br/>
- * .updater(new Sgd(0.05)) //... other hyperparameters <br/>
- *  .backprop(true)<br/>
- * .build();<br/><br/>
- *
- * With Deeplearning4j, you add a layer
- * by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
+ * and how fast the model should learn. This is what one configuration would look like: <br>
+ * <br>
+ * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br>
+ * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br>
+ * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br>
+ * .updater(new Sgd(0.05)) //... other hyperparameters <br>
+ * .backprop(true)<br>
+ * .build();<br>
+ * <br>
+ * With Deeplearning4j, you add a layer by calling layer on the
+ * NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
  * layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
- * nIn and nOut, as well as the type: DenseLayer.<br/><br/>
- *
- * .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br/>
- * .build())<br/><br/>
- *
- * Once you've configured your net, you train the
- * model with model.fit.
+ * nIn and nOut, as well as the type: DenseLayer.<br>
+ * <br>
+ * .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br>
+ * .build())<br>
+ * <br>
+ * Once you've configured your net, you train the model with model.fit.
  */
-
-
 @Data
 @Slf4j
-@Jacksonized
-@JsonIgnoreProperties(value={"net"}, ignoreUnknown = true)
-@EqualsAndHashCode(exclude = {"net"}, callSuper = true)
-//@JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
+@JsonIgnoreProperties(value = {"net"})
+@EqualsAndHashCode(callSuper = true)
+// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
 
-//The inner builder, that we can then extend ...
-@SuperBuilder //TODO fix access
+// The inner builder, that we can then extend ...
+@Jacksonized
+@SuperBuilder // TODO fix access
 public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 
-  private IModel net;
   private static final int DEFAULT_TBPTT_LENGTH = 20;
-  private boolean initCalled = false;
 
-  @Getter
-  @Setter
-  @NonNull
-  @lombok.Builder.Default
-  @Deprecated
+  @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
   protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
-  @Getter
-  @Setter
-  @NonNull
-  @lombok.Builder.Default
-  @Deprecated
+
+  @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
   protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
 
-  @Getter
-  @Setter
-  @lombok.Builder.Default
-  protected int iterationCount = 0;
-  //Counter for the number of epochs completed so far. Used for per-epoch schedules
-  @Getter
-  @Setter
-  @lombok.Builder.Default
-  protected int epochCount = 0;
-  @lombok.Builder.Default
-  protected double dampingFactor = 100;
+  @Getter @Setter @lombok.Builder.Default protected int iterationCount = 0;
+  // Counter for the number of epochs completed so far. Used for per-epoch schedules
+  @Getter @Setter @lombok.Builder.Default protected int epochCount = 0;
+  @lombok.Builder.Default protected double dampingFactor = 100;
+  @EqualsAndHashCode.Exclude private IModel net;
+  private boolean initCalled = false;
   // gradient keys used for ensuring order when getting and setting the gradient
   @lombok.Builder.Default private LinkedHashSet<String> netWideVariables = new LinkedHashSet<>();
 
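Rendered as compilable code, the configuration sketched in the class Javadoc above looks roughly like this (a sketch following the Javadoc; exact builder method names in this fork may differ):

```java
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;

public class ConfigSketch {
  public static void main(String[] args) {
    NeuralNetConfiguration conf =
        NeuralNetConfiguration.builder()
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.RELU)
            .updater(new Sgd(0.05)) // ... other hyperparameters
            // zero-indexed: layer 0 is the input layer
            .layer(0, DenseLayer.builder().nIn(784).nOut(250).build())
            .build();
    conf.init(); // wires net-wide defaults into the individual layer configurations
  }
}
```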
@@ -141,22 +108,19 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    */
   @Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
 
-
   /**
-   * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN.
-   * See {@link ConvolutionLayer.AlgoMode} for details.  Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory.
-   * <br>
-   * Note: values set by this method will be applied to all applicable layers in the network, unless a different
-   * value is explicitly set on a given layer. In other words: values set via this method are used as the default
-   * value, and can be overridden on a per-layer basis.
+   * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage
+   * of cuDNN. See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but
+   * "NO_WORKSPACE" uses less memory. <br>
+   * Note: values set by this method will be applied to all applicable layers in the network, unless
+   * a different value is explicitly set on a given layer. In other words: values set via this
+   * method are used as the default value, and can be overridden on a per-layer basis.
+   *
    * @param cudnnAlgoMode cuDNN algo mode to use
    */
-  @Getter
-  @Setter
-  @lombok.Builder.Default
+  @Getter @Setter @lombok.Builder.Default
   private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
 
-
   /**
    * Create a neural net configuration from json
    *
@@ -164,260 +128,23 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    * @return {@link NeuralNetConfiguration}
    */
   public static NeuralNetConfiguration fromJson(String json) {
-    NeuralNetConfiguration conf;
-    ObjectMapper mapper = NeuralNetConfiguration.mapper();
+    ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
     try {
-      conf = mapper.readValue(json, NeuralNetConfiguration.class);
-    } catch (InvalidTypeIdException e) {
-      if (e.getMessage().contains("@class")) {
-        try {
-          //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
-          return JsonMappers.getLegacyMapper().readValue(json, NeuralNetConfiguration.class);
-        } catch (InvalidTypeIdException e2) {
-          //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
-          //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
-          String msg = e2.getMessage();
-          if (msg != null && msg.contains("Could not resolve type id")) {
-            throw new RuntimeException(
-                "Error deserializing NeuralNetConfiguration - configuration may have a custom " +
-                    "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
-                    +
-                    " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
-                e);
-          }
-          throw new RuntimeException(e2);
-        } catch (IOException e2) {
-          throw new RuntimeException(e2);
-        }
-      }
-      throw new RuntimeException(e);
-    } catch (IOException e) {
-      //Check if this exception came from legacy deserializer...
-      String msg = e.getMessage();
-      if (msg != null && msg.contains("legacy")) {
-        throw new RuntimeException(
-            "Error deserializing NeuralNetConfiguration - configuration may have a custom " +
-                "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
-                +
-                "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
-            e);
-      }
+      return mapper.readValue(json, NeuralNetConfiguration.class);
+    } catch (JsonProcessingException e) {
       throw new RuntimeException(e);
     }
-
-    //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
-    // Previously: enumeration used for loss functions. Now: use classes
-    // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums
-    int layerCount = 0;
-    JsonNode confs = null;
-    for (LayerConfiguration nnc : conf.getFlattenedLayerConfigurations()) {
-      LayerConfiguration l = nnc;
-      if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFunction() == null) {
-        //lossFn field null -> may be an old config format, with lossFunction field being for the enum
-        //if so, try walking the JSON graph to extract out the appropriate enum value
-
-        BaseOutputLayer ol = (BaseOutputLayer) l;
-        try {
-          JsonNode jsonNode = mapper.readTree(json);
-          if (confs == null) {
-            confs = jsonNode.get("confs");
-          }
-          if (confs instanceof ArrayNode) {
-            ArrayNode layerConfs = (ArrayNode) confs;
-            JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
-            if (outputLayerNNCNode == null) {
-              throw new RuntimeException(
-                  "should never happen"); //return conf; //Should never happen...
-            }
-            JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
-
-            JsonNode lossFunctionNode = null;
-            if (outputLayerNode.has("output")) {
-              lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
-            } else if (outputLayerNode.has("rnnoutput")) {
-              lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
-            }
-
-            if (lossFunctionNode != null) {
-              String lossFunctionEnumStr = lossFunctionNode.asText();
-              LossFunctions.LossFunction lossFunction = null;
-              try {
-                lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
-              } catch (Exception e) {
-                log.warn(
-                    "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
-                    e);
-              }
-
-              if (lossFunction != null) {
-                switch (lossFunction) {
-                  case MSE:
-                    ol.setLossFunction(new LossMSE());
-                    break;
-                  case XENT:
-                    ol.setLossFunction(new LossBinaryXENT());
-                    break;
-                  case NEGATIVELOGLIKELIHOOD:
-                    ol.setLossFunction(new LossNegativeLogLikelihood());
-                    break;
-                  case MCXENT:
-                    ol.setLossFunction(new LossMCXENT());
-                    break;
-
-                  //Remaining: TODO
-                  case SQUARED_LOSS:
-                  case RECONSTRUCTION_CROSSENTROPY:
-                  default:
-                    log.warn(
-                        "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
-                        lossFunction);
-                    break;
-                }
-              }
-            }
-
-          } else {
-            log.warn(
-                "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
-                (confs != null ? confs.getClass() : null));
-          }
-        } catch (IOException e) {
-          log.warn(
-              "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
-              e);
-          break;
-        }
-      }
-
-      //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
-      //Try to load the old format if necessary, and create the appropriate IActivation instance
-      if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getActivationFn() == null) {
-        try {
-          JsonNode jsonNode = mapper.readTree(json);
-          if (confs == null) {
-            confs = jsonNode.get("confs");
-          }
-          if (confs instanceof ArrayNode) {
-            ArrayNode layerConfs = (ArrayNode) confs;
-            JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
-            if (outputLayerNNCNode == null) {
-              throw new RuntimeException(
-                  "Should never happen"); //return conf; //Should never happen...
-            }
-            JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
-
-            if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
-              continue;
-            }
-
-            JsonNode layerNode = layerWrapperNode.elements().next();
-            JsonNode activationFunction = layerNode.get(
-                "activationFunction"); //Should only have 1 element: "dense", "output", etc
-
-            if (activationFunction != null) {
-              Activation ia = Activation.fromString(activationFunction.asText());
-              ((BaseLayerConfiguration) l).setActivation(ia.getActivationFunction());
-            }
-          }
-
-        } catch (IOException e) {
-          log.warn(
-              "ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
-              e);
-        }
-      }
-
-      if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
-        return conf;
-      }
-
-      layerCount++;
-    }
-    return conf;
-  }
-
-  /**
-   * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
-   * from handling of {@link Activation} above.
-   *
-   * @return True if all is well and layer iteration shall continue. False else-wise.
-   */
-  private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
-      ObjectMapper mapper,
-      JsonNode confs, int layerCount) {
-    if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) {
-      try {
-        JsonNode jsonNode = mapper.readTree(json);
-        if (confs == null) {
-          confs = jsonNode.get("confs");
-        }
-        if (confs instanceof ArrayNode) {
-          ArrayNode layerConfs = (ArrayNode) confs;
-          JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
-          if (outputLayerNNCNode == null) {
-            return false; //Should never happen...
-          }
-          JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
-
-          if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
-            return true;
-          }
-
-          JsonNode layerNode = layerWrapperNode.elements().next();
-          JsonNode weightInit = layerNode.get(
-              "weightInit"); //Should only have 1 element: "dense", "output", etc
-          JsonNode distribution = layerNode.get("dist");
-
-          Distribution dist = null;
-          if (distribution != null) {
-            dist = mapper.treeToValue(distribution, Distribution.class);
-          }
-
-          if (weightInit != null) {
-            final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
-                .getWeightInitFunction(dist);
-            ((BaseLayerConfiguration) l).setWeightInit(wi);
-          }
-        }
-
-      } catch (IOException e) {
-        log.warn(
-            "ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
-            e);
-      }
-    }
-    return true;
-
-  }
-
-  /**
-   * Object mapper for serialization of configurations
-   *
-   * @return
-   */
-  public static ObjectMapper mapperYaml() {
-    return JsonMappers.getMapperYaml();
-  }
-
-  /**
-   * Object mapper for serialization of configurations
-   *
-   * @return
-   */
-  public static ObjectMapper mapper() {
-    return JsonMappers.getMapper();
   }
 
   public static NeuralNetConfiguration fromYaml(String input) {
-    throw new RuntimeException("Needs fixing - not supported."); //TODO
+    throw new RuntimeException("Needs fixing - not supported."); // TODO
   }
 
-
   /**
    * @return JSON representation of NN configuration
    */
   public String toYaml() {
-    ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
+    ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
     synchronized (mapper) {
       try {
         return mapper.writeValueAsString(this);
@@ -431,10 +158,12 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    * @return JSON representation of NN configuration
    */
   public String toJson() {
-    ObjectMapper mapper = NeuralNetConfiguration.mapper();
+    ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
     synchronized (mapper) {
-      //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
-      //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
+      // JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
+      // occasionally
+      // when writeValueAsString is used by multiple threads. This results in invalid JSON. See
+      // issue #3243
       try {
         return mapper.writeValueAsString(this);
       } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
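With both directions now routed through `CavisMapper`, a configuration should round-trip between `toJson()` and `fromJson()`. A minimal sketch (assuming the builder API shown in the class Javadoc, and only built-in layer types):

```java
NeuralNetConfiguration conf =
    NeuralNetConfiguration.builder()
        .layer(0, DenseLayer.builder().nIn(10).nOut(5).build())
        .build();

String json = conf.toJson(); // CavisMapper.Type.JSON under the hood
NeuralNetConfiguration restored = NeuralNetConfiguration.fromJson(json);

// Note: with the legacy-format fallbacks removed, any JsonProcessingException
// inside fromJson() now surfaces as a plain RuntimeException.
System.out.println(restored.toJson().equals(json)); // expected: true
```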
@@ -453,7 +182,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
   public NeuralNetConfiguration clone() {
     NeuralNetConfiguration clone;
     clone = (NeuralNetConfiguration) super.clone();
-    if(getStepFunction() != null) {   clone.setStepFunction(getStepFunction().clone()); }
+    if (getStepFunction() != null) {
+      clone.setStepFunction(getStepFunction().clone());
+    }
     clone.netWideVariables = new LinkedHashSet<>(netWideVariables);
     clone.setInnerConfigurations(new ArrayList<>(innerConfigurations));
 
@@ -473,98 +204,109 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
     clone.setDataType(this.getDataType());
 
     return clone;
-
   }
 
-  /**
-   *
-   */
+  /** */
   @Override
   public void init() {
-    if(initCalled) return;
-    initCalled=true;
+    if (initCalled) return;
+    initCalled = true;
 
-    /**
-     * Run init() for each layer
-     */
-    for( NeuralNetConfiguration nconf : getNetConfigurations() ) {
+    /** Run init() for each layer */
+    for (NeuralNetConfiguration nconf : getNetConfigurations()) {
       nconf.init();
     }
-    //getNetConfigurations().stream().forEach( conf -> {
+    // getNetConfigurations().stream().forEach( conf -> {
     // conf.init(); //do not call on self
-    //}); //call init on all embedded net configurations
+    // }); //call init on all embedded net configurations
 
-    //TODO do not put inside self to avoid serialization issues
+    // TODO do not put inside self to avoid serialization issues
     // innerConfigurations.add(0, this); //put this configuration at first place
 
+    getLayerConfigurations().stream()
+            .forEach(
+                    lconf ->
+                            lconf.setNetConfiguration(
+                                    this)); // set this as net config for all layers (defined in here, not stacked
+
     /**
-     * Inherit network wide configuration setting to those layer configurations
-     * that do not have an individual setting (nor a default)
+     * Inherit network wide configuration setting to those layer configurations that do not have an
+     * individual setting (nor a default)
      */
-    for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
+    for (LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
       lconf.runInheritance();
     }
 
-    getLayerConfigurations().stream().forEach( lconf -> lconf.setNetConfiguration(this)); //set this as net config for all layers (defined in here, not stacked
-
-    //Validate BackpropType setting
+    // Validate BackpropType setting
     if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
         && backpropType != BackpropType.TruncatedBPTT) {
-      log.warn("Truncated backpropagation through time lengths have been configured with values "
+      log.warn(
+          "Truncated backpropagation through time lengths have been configured with values "
               + tbpttFwdLength
-          + " and " + tbpttBackLength + " but backprop type is set to " + backpropType
-          + ". TBPTT configuration" +
-          " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
+              + " and "
+              + tbpttBackLength
+              + " but backprop type is set to "
+              + backpropType
+              + ". TBPTT configuration"
+              + " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
     }
 
     if (backpropType == BackpropType.TruncatedBPTT && isValidateTbpttConfig()) {
-      //Check for invalid combination - tbptt plus LastTimeStepLayer or
+      // Check for invalid combination - tbptt plus LastTimeStepLayer or
      for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
         LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
         if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
           throw new IllegalStateException(
               "Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
-                  +
-                  " cannot be used with layer " + i + " of type " + l.getClass().getName()
-                  + ": TBPTT is incompatible with this layer type (which is designed " +
-                  "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
-                  +
-                  "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
+                  + " cannot be used with layer "
+                  + i
+                  + " of type "
+                  + l.getClass().getName()
+                  + ": TBPTT is incompatible with this layer type (which is designed "
+                  + "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
+                  + "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
         }
       }
     }
 
     if (getInputType() == null && inputPreProcessors.get(0) == null) {
-      //User hasn't set the InputType. Sometimes we can infer it...
-      // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in
+      // User hasn't set the InputType. Sometimes we can infer it...
+      // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to
+      // feed in
       // standard feedforward or RNN data
-      //This isn't the most elegant implementation, but should avoid breaking backward compatibility here
-      //Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
+      // This isn't the most elegant implementation, but should avoid breaking backward
+      // compatibility here
+      // Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
       LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0);
       if (firstLayer instanceof BaseRecurrentLayer) {
         BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
         val nIn = brl.getNIn();
         if (nIn > 0) {
-          setInputType( InputType.recurrent(nIn, brl.getDataFormat()));
+          setInputType(InputType.recurrent(nIn, brl.getDataFormat()));
         }
-      } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
+      } else if (firstLayer instanceof DenseLayer
+          || firstLayer instanceof EmbeddingLayer
           || firstLayer instanceof OutputLayer) {
-        //Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer
+        // Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a
+        // FeedForwardLayer
         FeedForwardLayer ffl = (FeedForwardLayer) firstLayer;
         val nIn = ffl.getNIn();
         if (nIn > 0) {
-          setInputType( InputType.feedForward(nIn));
+          setInputType(InputType.feedForward(nIn));
         }
       }
     }
 
-    //Add preprocessors and set nIns, if InputType has been set
+    // Add preprocessors and set nIns, if InputType has been set
     // Builder.inputType field can be set in 1 of 4 ways:
     // 1. User calls setInputType directly
     // 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...))
-    // 3. Via the above code: i.e., assume input is as expected  by the RNN or dense layer -> sets the inputType field
-    if(inputPreProcessors == null) {
+    // 3. Via the above code: i.e., assume input is as expected  by the RNN or dense layer -> sets
+    // the inputType field
+    if (inputPreProcessors == null) {
       inputPreProcessors = new HashMap<>();
     }
     if (getInputType() != null) {
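Note that the TBPTT check above only logs a warning: custom `tbpttFwdLength`/`tbpttBackLength` values are silently ineffective unless the backprop type is switched as well. A sketch of a consistent configuration (builder setters assumed to mirror the field names, as `@SuperBuilder` generates them):

```java
NeuralNetConfiguration conf =
    NeuralNetConfiguration.builder()
        .backpropType(BackpropType.TruncatedBPTT) // required for the lengths to take effect
        .tbpttFwdLength(25)  // non-default: would only trigger the warning without the line above
        .tbpttBackLength(25)
        .build();
// Also keep LastTimeStep / GlobalPoolingLayer out of a TBPTT net:
// init() throws IllegalStateException for that combination (see check above).
```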
@@ -572,10 +314,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
       for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) {
         LayerConfiguration l = getFlattenedLayerConfigurations().get(i);
         if (inputPreProcessors.get(i) == null) {
-          //Don't override preprocessor setting, but set preprocessor if required...
+          // Don't override preprocessor setting, but set preprocessor if required...
           @NonNull
           InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType);
           if (inputPreProcessor != null) {
+            inputPreProcessors = new HashMap<>(inputPreProcessors);
             inputPreProcessors.put(i, inputPreProcessor);
           }
         }
@@ -586,41 +329,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
         }
         if (i > 0) {
           LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1);
-          //convolution 1d is an edge case where it has rnn input type but the filters
-          //should be the output
-          if (layer instanceof Convolution1DLayer) {
+          // convolution 1d is an edge case where it has rnn input type but the filters
+          // should be the output
+          if (layer instanceof Convolution1D || layer instanceof Convolution1DNew) {
             if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
               FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
               if (getInputType() instanceof InputType.InputTypeRecurrent) {
-                InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) getInputType();
+                InputType.InputTypeRecurrent recurrent =
+                    (InputType.InputTypeRecurrent) getInputType();
                 feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
               }
             } else {
-              l.setNIn(currentInputType,
-                  isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
+              l.setNIn(
+                  currentInputType,
+                  isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
+                                             // by the user
             }
           } else {
-            l.setNIn(currentInputType,
-                isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
+            l.setNIn(
+                currentInputType,
+                isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
+                                           // by the user
           }

         } else {
-          l.setNIn(currentInputType,
-              isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
+          l.setNIn(
+              currentInputType,
+              isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set by
+                                         // the user
         }

         currentInputType = l.getOutputType(i, currentInputType);
       }
-
     }

     Nd4j.getRandom().setSeed(getSeed());

-    //Validate output layer configuration
+    // Validate output layer configuration
     if (isValidateOutputLayerConfig()) {
-      //Validate output layer configurations...
+      // Validate output layer configurations...
       for (LayerConfiguration n : getFlattenedLayerConfigurations()) {
-        OutputLayerUtil.validateOutputLayer(n.getName(), n); //No-op for non output/loss layers
+        OutputLayerUtil.validateOutputLayer(n.getName(), n); // No-op for non output/loss layers
       }
     }
   }
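This loop is the heart of input-shape inference: the network's declared InputType is threaded through the layers, each layer's nIn is derived from the current type (unless the user set it manually and override is disabled), and getOutputType yields the type fed to the next layer. A reduced sketch of the same fold, using hypothetical interfaces rather than the real DL4J types:

import java.util.List;

interface Type {}

interface Layer {
  void setNIn(Type in, boolean override); // size the layer from the incoming type
  Type outputType(Type in);               // what this layer emits for that input
}

final class ShapeInference {
  /** Threads an input type through the layers, sizing each one along the way. */
  static Type infer(Type input, List<Layer> layers, boolean overrideNin) {
    Type current = input;
    for (Layer l : layers) {
      l.setNIn(current, overrideNin);
      current = l.outputType(current);
    }
    return current; // the network's output type
  }
}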
@@ -646,26 +395,28 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
         layerName = String.valueOf(i);
       }

-      //Pass input type through preprocessor, if necessary
+      // Pass input type through preprocessor, if necessary
       InputPreProcessor preproc = getInputPreProcess(i);
-      //TODO memory requirements for preprocessor
+      // TODO memory requirements for preprocessor
       if (preproc != null) {
         inputType = preproc.getOutputType(inputType);
       }

-      LayerMemoryReport report = getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
+      LayerMemoryReport report =
+          getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
       memoryReportMap.put(layerName, report);

       inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType);
     }

-    return new NetworkMemoryReport(memoryReportMap, NeuralNetConfiguration.class,
-        "MultiLayerNetwork", inputType);
+    return new NetworkMemoryReport(
+        memoryReportMap, NeuralNetConfiguration.class, "MultiLayerNetwork", inputType);
   }

   /**
    * For the given input shape/type for the network, return a list of activation sizes for each
-   * layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
+   * layer in the network.<br>
+   * i.e., list.get(i) is the output activation sizes for layer i
    *
    * @param inputType Input type for the network
    * @return A list of activation types for the network, indexed by layer number
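As the hunk shows, the network-level report is just the per-layer reports keyed by layer name, with the input type advanced through each preprocessor and layer in turn. A hedged usage sketch; the method and type names appear in this diff, but the surrounding setup (a fully built `conf`) is assumed:

// Hypothetical usage; 'conf' stands in for a fully built NeuralNetConfiguration.
InputType in = InputType.convolutional(28, 28, 1); // height, width, channels
NetworkMemoryReport report = conf.getMemoryReport(in);
System.out.println(report); // per-layer memory estimates, keyed by layer name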
@@ -699,38 +450,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
   public void addNetWideVariable(String variable) {
     if (!netWideVariables.contains(variable)) {
       netWideVariables.add(variable);
-      log.trace("Adding neural network wide variable '{}' to the list of variables. New length is {}.", variable, netWideVariables.size());
+      log.trace(
+          "Adding neural network wide variable '{}' to the list of variables. New length is {}.",
+          variable,
+          netWideVariables.size());
     }
-    log.trace("Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.", variable, netWideVariables.size());
+    log.trace(
+        "Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.",
+        variable,
+        netWideVariables.size());
   }

   public void clearNetWideVariable() {

     netWideVariables.clear();
-    log.trace("Adding neural network wide variables have been cleared. New length is {}.", netWideVariables.size());
+    log.trace(
+        "Neural network wide variables have been cleared. New length is {}.",
+        netWideVariables.size());
   }

   /**
-   * From the list of layers and neural net configurations, only return the Layer Configurations that
-   * are defined in this neural network (it does not include embedded neural network configuration
-   * layers)
+   * From the list of layers and neural net configurations, only return the Layer Configurations
+   * that are defined in this neural network (it does not include embedded neural network
+   * configuration layers)
+   *
    * @return list with layer configurations
    */
   @JsonIgnore
   public List<LayerConfiguration> getLayerConfigurations() {
     return innerConfigurations.stream()
         .filter(obj -> (obj instanceof LayerConfiguration))
-        .map( obj -> (LayerConfiguration)obj )
-        .collect( Collectors.toList());
+        .map(obj -> (LayerConfiguration) obj)
+        .collect(Collectors.toList());
   }

   /**
-   * From the list of layers and neural net configurations, only return the neural net configurations
+   * From the list of layers and neural net configurations, only return the neural net
+   * configurations
+   *
    * @return list with neural net configurations
    */
-  //@Synchronized("innerConfigurationsLock")
+  // @Synchronized("innerConfigurationsLock")
   @JsonIgnore
   public List<NeuralNetConfiguration> getNetConfigurations() {
     List<NeuralNetConfiguration> list;
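Splitting the log.trace calls across lines is a format-only change; the {} placeholders still defer message construction until TRACE is actually enabled. Note that, as written, the "Skipped adding…" message fires on every call, even after a successful add. A guarded variant would look like the following sketch (illustrative class, not the committed code):

import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class NetWideVariables {
  private static final Logger log = LoggerFactory.getLogger(NetWideVariables.class);
  private final List<String> vars = new ArrayList<>();

  void add(String variable) {
    if (!vars.contains(variable)) {
      vars.add(variable);
      log.trace("Added net-wide variable '{}'. New length is {}.", variable, vars.size());
    } else { // the guard keeps the 'skipped' message out of the successful-add path
      log.trace("Skipped net-wide variable '{}'; already present. Length remains {}.",
          variable, vars.size());
    }
  }
}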
@@ -751,35 +511,47 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    * @return list of layer configurations
    */
   public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
-    List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list
-    //When properly initialized, _this_ configuration is set first in the list, however we
-    //can find cases where this is not true, thus the first configuration is another net or layer configuration
-    //and should not be skipped. In essence, skip first configuration if that is "this".
-    //TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
+    List<LayerConfiguration> ret = new ArrayList<>(); // create the final return list
+    // When properly initialized, _this_ configuration is set first in the list, however we
+    // can find cases where this is not true, thus the first configuration is another net or layer
+    // configuration
+    // and should not be skipped. In essence, skip first configuration if that is "this".
+    // TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
     int iSkip = 0;
-    if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
-    conf.getInnerConfigurations().stream().skip(iSkip)
-            .forEach(obj -> {
-              //if Layer Config, include in list and inherit parameters from this conf
-              //else if neural net configuration, call self recursively to resolve layer configurations
+    if (conf.getInnerConfigurations().size() > 0
+        && conf.getInnerConfigurations().get(0).equals(this)) {
+      iSkip = 1;
+    }
+    conf.getInnerConfigurations().stream()
+        .skip(iSkip)
+        .forEach(
+            obj -> {
+              // if Layer Config, include in list and inherit parameters from this conf
+              // else if neural net configuration, call self recursively to resolve layer
+              // configurations
               if (obj instanceof LayerConfiguration) {
                 ((LayerConfiguration) obj).setNetConfiguration(conf);
                 ret.add((LayerConfiguration) obj);
               } else if (obj instanceof NeuralNetConfiguration)
-                ret.addAll(getFlattenedLayerConfigurations(
-                        (NeuralNetConfiguration) obj));
+                ret.addAll(getFlattenedLayerConfigurations((NeuralNetConfiguration) obj));
               else {
                 log.error(
                     "The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
                     obj.getClass().getSimpleName());
               }
             });
+    // make sure the indexes are sequenced properly
+    AtomicInteger i = new AtomicInteger();
+    ret.forEach(obj -> {
+      obj.setIndex(i.getAndIncrement());
+    });
     return ret;
   }

   /**
-   * Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this configurations
-   * list of configurations
+   * Same as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this
+   * configuration's list of configurations
+   *
    * @return list of layer configurations
    */
   @JsonIgnore
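The new AtomicInteger pass renumbers the flattened layers so that each LayerConfiguration's index matches its position in the final list, no matter how deeply it was nested in an embedded configuration. The AtomicInteger is there only because a lambda cannot capture a mutable local int; an indexed loop is equivalent, as in this sketch (helper name is hypothetical):

import java.util.List;
import java.util.function.ObjIntConsumer;

final class IndexSequencer {
  /** Renumber items so that index == position in the flattened list. */
  static <T> void resequence(List<T> items, ObjIntConsumer<T> setIndex) {
    for (int idx = 0; idx < items.size(); idx++) {
      setIndex.accept(items.get(idx), idx);
    }
  }
}

// e.g. IndexSequencer.resequence(ret, LayerConfiguration::setIndex);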
@@ -789,6 +561,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {

   /**
    * Add a new layer to the first position
+   *
    * @param layer configuration
    */
   public void setLayer(@NonNull LayerConfiguration layer) {
@@ -801,26 +574,28 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
   }

   /**
-   * Deprecated, do not use. Workaround for old tests
-   * and getFlattenedLayerConfigurations().get(0);
+   * Deprecated, do not use. Workaround for old tests and getFlattenedLayerConfigurations().get(0);
+   *
    * @return
    */
-  @Deprecated @JsonIgnore
+  @Deprecated
+  @JsonIgnore
   public LayerConfiguration getFirstLayer() {
     log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
     return getFlattenedLayerConfigurations().get(0);
   }

+  /*
+
     protected boolean canEqual(final Object other) {
       return other instanceof NeuralNetConfiguration;
     }
+  */

-  public static abstract class NeuralNetConfigurationBuilder<C extends NeuralNetConfiguration,
-      B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>> extends
-      NeuralNetBaseBuilderConfigurationBuilder<C, B> {
+  public abstract static class NeuralNetConfigurationBuilder<
+          C extends NeuralNetConfiguration,
+          B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>>
+      extends NeuralNetBaseBuilderConfigurationBuilder<C, B> {

     public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
       return new ComputationGraphConfiguration.GraphBuilder(this);
@@ -829,10 +604,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
     public NeuralNetConfigurationBuilder clone() {
       try {
         return (NeuralNetConfigurationBuilder) super.clone();
-      } catch(CloneNotSupportedException ex) {
+      } catch (CloneNotSupportedException ex) {
         throw new RuntimeException(ex);
       }
     }
-
   }
 }
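The reordered builder declaration is the usual self-referential ("curiously recurring") builder signature: B is constrained to the concrete builder subtype so inherited fluent methods can return B and keep chaining type-safe down the NeuralNetBaseBuilderConfigurationBuilder hierarchy. The shape of the pattern, stripped to essentials with illustrative names:

abstract class Base<C, B extends Base<C, B>> {
  @SuppressWarnings("unchecked")
  B self() { return (B) this; }          // the cast is safe by the bound on B
  B seed(long s) { /* store s */ return self(); } // fluent setters return B, not Base
  abstract C build();
}

final class Concrete extends Base<String, Concrete> {
  @Override String build() { return "built"; } // chaining preserves the Concrete type
}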
@@ -21,7 +21,13 @@

 package org.deeplearning4j.nn.conf;

+/**
+ * N is the batch size<br/>
+ * C is the number of feature maps (that is, the number of channels)<br/>
+ * H is the image height (not used for 1D conv, as this is an RNN format)<br/>
+ * W is the image width<br/>
+ */
 public enum RNNFormat implements DataFormat {
-    NCW,
-    NWC
+    /** n=batch size; c=channels/features; w=width */ NCW,
+    /** n=batch size; w=width; c=channels/features */ NWC
 }
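Concretely: for a recurrent activation with batch size 32, 128 features, and 100 time steps, NCW means shape [32, 128, 100] and NWC means [32, 100, 128]. Converting between the two is a single axis permutation; a small ND4J sketch, assuming ND4J is on the classpath:

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RnnFormatDemo {
  public static void main(String[] args) {
    INDArray ncw = Nd4j.rand(new int[] {32, 128, 100}); // [batch, channels, width]
    INDArray nwc = ncw.permute(0, 2, 1);                // [batch, width, channels]
    System.out.println(Arrays.toString(nwc.shape()));   // [32, 100, 128]
  }
}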
@@ -20,6 +20,9 @@

 package org.deeplearning4j.nn.conf.constraint;

+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
 import lombok.*;
 import org.apache.commons.lang3.ArrayUtils;
 import org.deeplearning4j.nn.api.Layer;
@@ -27,11 +30,6 @@ import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.nd4j.linalg.api.ndarray.INDArray;

-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-
-
 @AllArgsConstructor
 @EqualsAndHashCode
 @Data
@@ -20,6 +20,8 @@

 package org.deeplearning4j.nn.conf.constraint;

+import java.util.Collections;
+import java.util.Set;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,9 +29,6 @@ import org.nd4j.linalg.factory.Broadcast;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;

-import java.util.Collections;
-import java.util.Set;
-
 @Data
 @EqualsAndHashCode(callSuper = true)
 public class MaxNormConstraint extends BaseConstraint {
Some files were not shown because too many files have changed in this diff.