/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nn.conf;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.Arrays;
import java.util.Properties;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

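/**
 * Unit tests for {@link NeuralNetConfiguration}: JSON/YAML round-trips, cloning,
 * seeded weight initialization, training listeners, and rejection of invalid configurations.
 */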
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {

    @TempDir
    public File testDir;

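    /**
     * Builds a minimal two-layer configuration (dense + MSE output) with a fixed seed,
     * shared by several tests below.
     */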
    private static NeuralNetConfiguration getConf() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L)
                .layer(0, DenseLayer.builder().nIn(2).nOut(2)
                        .dist(new NormalDistribution(0, 1)).build())
                .layer(1, OutputLayer.builder().nIn(2).nOut(1)
                        .activation(Activation.TANH)
                        .dist(new NormalDistribution(0, 1))
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .build())
                .build();
        return conf;
    }

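    /**
     * Round-trips a configuration through JSON, both directly and via a {@link Properties}
     * file on disk, and checks that the layer configuration survives unchanged.
     */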
    @Test
    public void testJson() throws Exception {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .layer(0, DenseLayer.builder().dist(new NormalDistribution(1, 1e-1)).build())
                .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();

        String json = conf.toJson();
        NeuralNetConfiguration from = NeuralNetConfiguration.fromJson(json);
        assertEquals(conf.getConf(0), from.getConf(0));

        Properties props = new Properties();
        props.put("json", json);
        assertEquals(json, props.getProperty("json"));

        File f = new File(testDir, "props");
        f.deleteOnExit();
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f))) {
            props.store(bos, "");
        }

        Properties props2 = new Properties();
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            props2.load(bis);
        }
        assertEquals(props.getProperty("json"), props2.getProperty("json"));

        String json2 = props2.getProperty("json");
        NeuralNetConfiguration conf3 = NeuralNetConfiguration.fromJson(json2);
        assertEquals(conf.getConf(0), conf3.getConf(0));
    }

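    /**
     * Serializes a convolutional network configuration (conv/subsampling/dense/output)
     * to JSON and verifies that deserialization reproduces an equal configuration.
     */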
    @Test
    public void testConvnetJson() {
        final int numRows = 76;
        final int numColumns = 76;
        int nChannels = 3;
        int outputNum = 6;
        int seed = 123;

        // set up the network
        NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
                .l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true)
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                .layer(0, ConvolutionLayer.builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(1, SubsamplingLayer.builder(SubsamplingLayer.PoolingType.MAX, new int[]{2, 2}).build())
                .layer(2, ConvolutionLayer.builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(3, SubsamplingLayer.builder(SubsamplingLayer.PoolingType.MAX, new int[]{2, 2}).build())
                .layer(4, DenseLayer.builder().nOut(100).activation(Activation.RELU).build())
                .layer(5, OutputLayer.builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
                        .build())
                .inputType(InputType.convolutional(numRows, numColumns, nChannels));

        NeuralNetConfiguration conf = builder.build();
        String json = conf.toJson();
        NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json);
        assertEquals(conf, conf2);
    }

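    /**
     * Same JSON round-trip as above, but with {@link Upsampling2D} layers in place of
     * the subsampling layers.
     */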
    @Test
    public void testUpsamplingConvnetJson() {
        final int numRows = 76;
        final int numColumns = 76;
        int nChannels = 3;
        int outputNum = 6;
        int seed = 123;

        // set up the network
        NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
                .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true)
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                .layer(ConvolutionLayer.builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(Upsampling2D.builder().size(2).build())
                .layer(2, ConvolutionLayer.builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(Upsampling2D.builder().size(2).build())
                .layer(4, DenseLayer.builder().nOut(100).activation(Activation.RELU).build())
                .layer(5, OutputLayer.builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
                        .build())
                .inputType(InputType.convolutional(numRows, numColumns, nChannels));

        NeuralNetConfiguration conf = builder.build();
        String json = conf.toJson();
        NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json);
        assertEquals(conf, conf2);
    }

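    /**
     * JSON round-trip for a configuration containing a {@link GlobalPoolingLayer} with
     * p-norm pooling.
     */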
    @Test
    public void testGlobalPoolingJson() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp())
                .dist(new NormalDistribution(0, 1.0)).seed(12345L)
                .layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(5).build())
                .layer(1, GlobalPoolingLayer.builder().poolingType(PoolingType.PNORM).pnorm(3).build())
                .layer(2, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nOut(3).build())
                .inputType(InputType.convolutional(32, 32, 1)).build();

        String str = conf.toJson();
        NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(str);

        assertEquals(conf, fromJson);
    }

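    /**
     * YAML counterpart of {@link #testJson()}: round-trips the configuration through
     * YAML directly and via a {@link Properties} file on disk.
     */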
    @Test
    public void testYaml() throws Exception {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .layer(0, DenseLayer.builder().dist(new NormalDistribution(1, 1e-1)).build())
                .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();

        String yaml = conf.toYaml();
        NeuralNetConfiguration from = NeuralNetConfiguration.fromYaml(yaml);
        assertEquals(conf.getConf(0), from.getConf(0));

        Properties props = new Properties();
        props.put("yaml", yaml);
        assertEquals(yaml, props.getProperty("yaml"));

        File f = new File(testDir, "props");
        f.deleteOnExit();
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f))) {
            props.store(bos, "");
        }

        Properties props2 = new Properties();
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            props2.load(bis);
        }
        assertEquals(props.getProperty("yaml"), props2.getProperty("yaml"));

        String yaml2 = props2.getProperty("yaml");
        NeuralNetConfiguration conf3 = NeuralNetConfiguration.fromYaml(yaml2);
        assertEquals(conf.getConf(0), conf3.getConf(0));
    }

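    /**
     * Verifies that {@link NeuralNetConfiguration#clone()} produces an equal but fully
     * independent copy: layer configurations and input preprocessors must be distinct objects.
     */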
    @Test
    public void testClone() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .layer(0, DenseLayer.builder().build())
                .layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).build())
                .inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build();

        NeuralNetConfiguration conf2 = conf.clone();

        assertEquals(conf, conf2);
        assertNotSame(conf, conf2);
        assertNotSame(conf.getNetConfigurations(), conf2.getNetConfigurations());
        for (int i = 0; i < conf.getNetConfigurations().size(); i++) {
            assertNotSame(conf.getConf(i), conf2.getConf(i));
        }
        assertNotSame(conf.getInputPreProcessors(), conf2.getInputPreProcessors());
        for (Integer layer : conf.getInputPreProcessors().keySet()) {
            assertNotSame(conf.getInputPreProcess(layer), conf2.getInputPreProcess(layer));
        }
    }

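    /**
     * Two networks built from the same seeded configuration must initialize to
     * identical parameters.
     */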
    @Test
    public void testRandomWeightInit() {
        MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
        model1.init();

        Nd4j.getRandom().setSeed(12345L);
        MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
        model2.init();

        float[] p1 = model1.getModelParams().data().asFloat();
        float[] p2 = model2.getModelParams().data().asFloat();
        System.out.println(Arrays.toString(p1));
        System.out.println(Arrays.toString(p2));

        assertArrayEquals(p1, p2, 0.0f);
    }

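    /**
     * Training listeners must be propagated to every layer, regardless of whether they
     * are added before or after {@code init()}.
     */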
    @Test
    public void testTrainingListener() {
        MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
        model1.init();
        model1.addTrainingListeners(new ScoreIterationListener(1));

        MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
        model2.addTrainingListeners(new ScoreIterationListener(1));
        model2.init();

        Layer[] l1 = model1.getLayers();
        for (int i = 0; i < l1.length; i++) {
            assertTrue(l1[i].getTrainingListeners() != null && l1[i].getTrainingListeners().size() == 1);
        }

        Layer[] l2 = model2.getLayers();
        for (int i = 0; i < l2.length; i++) {
            assertTrue(l2[i].getTrainingListeners() != null && l2[i].getTrainingListeners().size() == 1);
        }
    }

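    /**
     * Invalid configurations (no layers, or layer indices that do not start at 0 or are
     * not contiguous) must fail with an {@link IllegalStateException} at init time.
     */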
    @Test
    public void testInvalidConfig() {
        // Case 1: no layers at all
        try {
            NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            //OK
            log.info(e.toString());
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }

        // Case 2: first layer index is not 0
        try {
            NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                    .layer(1, DenseLayer.builder().nIn(3).nOut(4).build())
                    .layer(2, OutputLayer.builder().nIn(4).nOut(5).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            //OK
            log.info(e.toString());
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }

        // Case 3: gap in layer indices (0, then 2)
        try {
            NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                    .layer(0, DenseLayer.builder().nIn(3).nOut(4).build())
                    .layer(2, OutputLayer.builder().nIn(4).nOut(5).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            //OK
            log.info(e.toString());
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }
    }

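    /**
     * The indexed and non-indexed {@code layer(...)} builder overloads must produce
     * equal configurations.
     */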
    @Test
    public void testListOverloads() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                .layer(0, DenseLayer.builder().nIn(3).nOut(4).build())
                .layer(1, OutputLayer.builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer();
        assertEquals(3, dl.getNIn());
        assertEquals(4, dl.getNOut());
        OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer();
        assertEquals(4, ol.getNIn());
        assertEquals(5, ol.getNOut());

        NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345)
                .layer(0, DenseLayer.builder().nIn(3).nOut(4).build())
                .layer(1, OutputLayer.builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();

        NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder().seed(12345)
                .layer(DenseLayer.builder().nIn(3).nOut(4).build())
                .layer(OutputLayer.builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
        net3.init();

        assertEquals(conf, conf2);
        assertEquals(conf, conf3);
    }

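    /**
     * A separate bias updater set via {@code biasUpdater(...)} must apply its learning
     * rate to the bias parameter of every layer, while weights keep the main updater's rate.
     */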
    @Test
    public void testBiasLr() {
        // set up the network
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                .updater(new Adam(1e-2))
                .biasUpdater(new Adam(0.5))
                .layer(0, ConvolutionLayer.builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(1, DenseLayer.builder().nOut(100).activation(Activation.RELU).build())
                .layer(2, DenseLayer.builder().nOut(100).activation(Activation.RELU).build())
                .layer(3, OutputLayer.builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
                        .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
                .inputType(InputType.convolutional(28, 28, 1)).build();

        conf.init();

        BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
        BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
        BaseLayerConfiguration l2 = (BaseLayerConfiguration) conf.getConf(2).getLayer();
        BaseLayerConfiguration l3 = (BaseLayerConfiguration) conf.getConf(3).getLayer();

        assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
        assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);

        assertEquals(0.5, ((Adam) l1.getUpdaterByParam("b")).getLearningRate(), 1e-6);
        assertEquals(1e-2, ((Adam) l1.getUpdaterByParam("W")).getLearningRate(), 1e-6);

        assertEquals(0.5, ((Adam) l2.getUpdaterByParam("b")).getLearningRate(), 1e-6);
        assertEquals(1e-2, ((Adam) l2.getUpdaterByParam("W")).getLearningRate(), 1e-6);

        assertEquals(0.5, ((Adam) l3.getUpdaterByParam("b")).getLearningRate(), 1e-6);
        assertEquals(1e-2, ((Adam) l3.getUpdaterByParam("W")).getLearningRate(), 1e-6);
    }

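    /**
     * Output layer validation must reject mismatched loss/activation combinations
     * (e.g. nOut=1 with softmax, MCXENT with tanh) unless validation is explicitly disabled.
     */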
    @Test
    public void testInvalidOutputLayer() {
        /*
        Test cases (invalid configs):
        1. nOut=1 + softmax
        2. mcxent + tanh
        3. xent + softmax
        4. xent + relu
        5. mcxent + sigmoid
         */

        LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{
                LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT,
                LossFunctions.LossFunction.XENT,
                LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT};
        int[] nOut = new int[]{1, 3, 3, 3, 3};
        Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH,
                Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID};
        for (int i = 0; i < lf.length; i++) {
            for (boolean lossLayer : new boolean[]{false, true}) {
                for (boolean validate : new boolean[]{true, false}) {
                    String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer
                            + ",validate=" + validate;
                    if (nOut[i] == 1 && lossLayer) {
                        continue; //nOut is not available in a loss layer, so validation can't detect this case
                    }
                    try {
                        NeuralNetConfiguration.builder()
                                .layer(DenseLayer.builder().nIn(10).nOut(10).build())
                                .layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i])
                                        .activation(activations[i]).lossFunction(lf[i]).build()
                                        : LossLayer.builder().activation(activations[i])
                                                .lossFunction(lf[i].getILossFunction()).build())
                                .validateOutputLayerConfig(validate)
                                .build();
                        if (validate) {
                            fail("Expected exception: " + s);
                        }
                    } catch (DL4JInvalidConfigException e) {
                        if (validate) {
                            assertTrue(e.getMessage().toLowerCase().contains("invalid output"), s);
                        } else {
                            fail("No exception should be thrown when validation is disabled: " + s);
                        }
                    }
                }
            }
        }
    }
}