/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.misc;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import static org.junit.Assert.assertEquals;
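/**
 * Tests for changing learning rates (fixed values and schedules) on already-initialized
 * MultiLayerNetwork and ComputationGraph instances. Each test compares the modified network
 * against an equivalent network that was configured with the new learning rate from the start.
 */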
public class TestLrChanges extends BaseDL4JTest {
@Test
public void testChangeLrMLN(){
//First: Set LR for a *single* layer and compare vs. equivalent net config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.1)).build())
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for( int i=0; i<10; i++ ){
net.fit(Nd4j.rand(10,10), Nd4j.rand(10,10));
}
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.5)).build()) //0.5 LR
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
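//Bring net2 into sync with net: copy the updater state (Adam/RMSProp running averages),
//iteration count and parameters, so that only the configured learning rate differs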
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup());
assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0);
net.setLearningRate(0, 0.5); //Set LR for layer 0 to 0.5
assertEquals(0.5, net.getLearningRate(0).doubleValue(), 0.0);
assertEquals(conf, conf2);
assertEquals(conf.toJson(), conf2.toJson());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(in, l);
net2.fit(in, l);
}
assertEquals(net.params(), net2.params());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
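//With identical params and updater state, both nets should compute the same score on the same data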
INDArray in1 = Nd4j.rand(10, 10);
INDArray l1 = Nd4j.rand(10, 10);
net.setInput(in1);
net.setLabels(l1);
net.computeGradientAndScore();
net2.setInput(in1);
net2.setLabels(l1);
net2.computeGradientAndScore();
assertEquals(net.score(), net2.score(), 1e-8);
//Now: set *all* LRs to, say, 0.3...
MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.3)).build()) //0.3 LR
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.3)).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
net3.init();
net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf3.setIterationCount(conf.getIterationCount());
net3.setParams(net.params().dup());
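//setLearningRate(double) with no layer index applies the new LR to every layer's updater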
net.setLearningRate(0.3);
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(in, l);
net3.fit(in, l);
}
assertEquals(net.params(), net3.params());
assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
}
@Test
public void testChangeLrSgd() {
//Simple test for nets using the plain SGD updater (no updater state to keep in sync)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.updater(new Sgd(0.1))
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setLearningRate(1.0);
net.setLearningRate(1, 0.5);
assertEquals(1.0, net.getLearningRate(0), 0.0);
assertEquals(0.5, net.getLearningRate(1), 0.0);
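//The same learning rate setters are available on ComputationGraph, with layers addressed by name instead of index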
ComputationGraph cg = net.toComputationGraph();
cg.setLearningRate(2.0);
cg.setLearningRate("1", 2.5);
assertEquals(2.0, cg.getLearningRate("0"), 0.0);
assertEquals(2.5, cg.getLearningRate("1"), 0.0);
}
@Test
public void testChangeLrMLNSchedule(){
//First: Set an LR schedule for the whole network and compare vs. equivalent net config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for( int i=0; i<10; i++ ){
net.fit(Nd4j.rand(10,10), Nd4j.rand(10,10));
}
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.updater(new Adam(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )))
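//ExponentialSchedule(ITERATION, 0.5, 0.8): LR at iteration i is 0.5 * 0.8^i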
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup());
net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set the LR schedule for all layers
assertEquals(conf, conf2);
assertEquals(conf.toJson(), conf2.toJson());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(in, l);
net2.fit(in, l);
}
assertEquals(net.params(), net2.params());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
}
@Test
public void testChangeLrCompGraph(){
//First: Set LR for a *single* layer and compare vs. equivalent net config
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.1)).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build(), "0")
.addLayer("2", new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
for( int i=0; i<10; i++ ){
net.fit(new DataSet(Nd4j.rand(10,10), Nd4j.rand(10,10)));
}
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.5)).build(), "in") //0.5 LR
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build(), "0")
.addLayer("2", new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup());
assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0);
net.setLearningRate("0", 0.5); //Set LR for layer 0 to 0.5
assertEquals(0.5, net.getLearningRate("0").doubleValue(), 0.0);
assertEquals(conf, conf2);
assertEquals(conf.toJson(), conf2.toJson());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(new DataSet(in, l));
net2.fit(new DataSet(in, l));
}
assertEquals(net.params(), net2.params());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
INDArray in1 = Nd4j.rand(10, 10);
INDArray l1 = Nd4j.rand(10, 10);
net.setInputs(in1);
net.setLabels(l1);
net.computeGradientAndScore();
net2.setInputs(in1);
net2.setLabels(l1);
net2.computeGradientAndScore();
assertEquals(net.score(), net2.score(), 1e-8);
//Now: set *all* LRs to, say, 0.3...
MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.3)).build()) //0.3 LR
.layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.3)).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
net3.init();
net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf3.setIterationCount(conf.getIterationCount());
net3.setParams(net.params().dup());
net.setLearningRate(0.3);
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(new DataSet(in, l));
net3.fit(new DataSet(in, l));
}
assertEquals(net.params(), net3.params());
assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
}
@Test
public void testChangeLrCompGraphSchedule(){
//First: Set an LR schedule for the whole network and compare vs. equivalent net config
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.updater(new Adam(0.1))
.graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).build(), "0")
.addLayer("2", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
for( int i=0; i<10; i++ ){
net.fit(new DataSet(Nd4j.rand(10,10), Nd4j.rand(10,10)));
}
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.seed(12345)
.updater(new Adam(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )))
.graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).build(), "0")
.layer("2", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup());
net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set the LR schedule for all layers
assertEquals(conf, conf2);
assertEquals(conf.toJson(), conf2.toJson());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
//Perform some parameter updates - check things are actually in sync...
for( int i=0; i<3; i++ ){
INDArray in = Nd4j.rand(10, 10);
INDArray l = Nd4j.rand(10, 10);
net.fit(new DataSet(in, l));
net2.fit(new DataSet(in, l));
}
assertEquals(net.params(), net2.params());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
}
}