/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
package org.deeplearning4j.optimize.solver;
|
|
|
|
|
|
|
|
|
|
import lombok.val;
|
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
|
|
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
|
|
|
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
|
|
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
|
|
|
import org.deeplearning4j.nn.layers.OutputLayer;
|
|
|
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
|
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
|
|
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
|
|
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
|
|
|
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
|
|
|
|
import org.deeplearning4j.optimize.solvers.BackTrackLineSearch;
|
|
|
|
|
import org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction;
|
|
|
|
|
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
|
2021-03-15 13:02:01 +09:00
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
|
|
|
import org.junit.jupiter.api.Test;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.activations.Activation;
|
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
|
import org.nd4j.linalg.dataset.DataSet;
|
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
|
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
|
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
import java.util.Collections;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2021-03-15 13:02:01 +09:00
|
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @author Adam Gibson
|
|
|
|
|
*/
|
2022-09-20 15:40:53 +02:00
|
|
|
public class BackTrackLineSearchTest extends BaseDL4JTest {
|
2019-06-06 15:21:15 +03:00
|
|
|
private DataSetIterator irisIter;
|
|
|
|
|
private DataSet irisData;
|
|
|
|
|
|
2021-03-15 13:02:01 +09:00
|
|
|
@BeforeEach
|
2022-09-20 15:40:53 +02:00
|
|
|
public void before() {
|
2019-06-06 15:21:15 +03:00
|
|
|
if (irisIter == null) {
|
|
|
|
|
irisIter = new IrisDataSetIterator(5, 5);
|
|
|
|
|
}
|
|
|
|
|
if (irisData == null) {
|
|
|
|
|
irisData = irisIter.next();
|
|
|
|
|
irisData.normalizeZeroMeanZeroUnitVariance();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testSingleMinLineSearch() throws Exception {
|
|
|
|
|
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
|
|
|
|
|
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
|
|
|
|
|
int nParams = (int)layer.numParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
|
|
|
|
|
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
|
|
|
|
|
layer.setLabels(irisData.getLabels());
|
|
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer());
|
2023-03-23 17:39:00 +01:00
|
|
|
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
assertEquals(1.0, step, 1e-3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testSingleMaxLineSearch() throws Exception {
|
2019-06-06 15:21:15 +03:00
|
|
|
double score1, score2;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
|
|
|
|
|
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
|
|
|
|
|
int nParams = (int)layer.numParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
|
|
|
|
|
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
|
|
|
|
|
layer.setLabels(irisData.getLabels());
|
|
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2023-03-23 17:39:00 +01:00
|
|
|
score1 = layer.getScore();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
BackTrackLineSearch lineSearch =
|
|
|
|
|
new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer());
|
2023-03-23 17:39:00 +01:00
|
|
|
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
assertEquals(1.0, step, 1e-3);
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testMultMinLineSearch() throws Exception {
|
2019-06-06 15:21:15 +03:00
|
|
|
double score1, score2;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
|
|
|
|
|
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
|
|
|
|
|
int nParams = (int)layer.numParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
|
|
|
|
|
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
|
|
|
|
|
layer.setLabels(irisData.getLabels());
|
|
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2023-03-23 17:39:00 +01:00
|
|
|
score1 = layer.getScore();
|
2019-06-06 15:21:15 +03:00
|
|
|
INDArray origGradient = layer.gradient().gradient().dup();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction();
|
|
|
|
|
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
|
2023-03-23 17:39:00 +01:00
|
|
|
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
|
|
|
|
|
INDArray currParams = layer.getModelParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
sf.step(currParams, origGradient, step);
|
2023-03-23 17:39:00 +01:00
|
|
|
layer.setParamsTable(currParams);
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
score2 = layer.getScore();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2);
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testMultMaxLineSearch() throws Exception {
|
2019-06-06 15:21:15 +03:00
|
|
|
double score1, score2;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
irisData.normalizeZeroMeanZeroUnitVariance();
|
|
|
|
|
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.MCXENT);
|
2022-09-20 15:40:53 +02:00
|
|
|
int nParams = (int)layer.numParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
|
|
|
|
|
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
|
|
|
|
|
layer.setLabels(irisData.getLabels());
|
|
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2023-03-23 17:39:00 +01:00
|
|
|
score1 = layer.getScore();
|
2019-06-06 15:21:15 +03:00
|
|
|
INDArray origGradient = layer.gradient().gradient().dup();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
DefaultStepFunction sf = new DefaultStepFunction();
|
|
|
|
|
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
|
2023-03-23 17:39:00 +01:00
|
|
|
double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(),
|
2022-09-20 15:40:53 +02:00
|
|
|
layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
INDArray currParams = layer.getModelParams();
|
2019-06-06 15:21:15 +03:00
|
|
|
sf.step(currParams, origGradient, step);
|
2023-03-23 17:39:00 +01:00
|
|
|
layer.setParamsTable(currParams);
|
2019-06-06 15:21:15 +03:00
|
|
|
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
|
2023-03-23 17:39:00 +01:00
|
|
|
score2 = layer.getScore();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2);
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunction, int maxIterations,
|
|
|
|
|
LossFunctions.LossFunction lossFunction) {
|
|
|
|
|
NeuralNetConfiguration conf =
|
2023-03-23 17:39:00 +01:00
|
|
|
NeuralNetConfiguration.builder().seed(12345L).miniBatch(true)
|
2022-09-20 15:40:53 +02:00
|
|
|
.maxNumLineSearchIterations(maxIterations)
|
2023-04-24 18:09:11 +02:00
|
|
|
.layer(org.deeplearning4j.nn.conf.layers.OutputLayer.builder().lossFunction(lossFunction)
|
2022-09-20 15:40:53 +02:00
|
|
|
.nIn(4).nOut(3).activation(activationFunction)
|
|
|
|
|
.weightInit(WeightInit.XAVIER).build())
|
|
|
|
|
.build();
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
val numParams = conf.getFirstLayer().initializer().numParams(conf);
|
2019-06-06 15:21:15 +03:00
|
|
|
INDArray params = Nd4j.create(1, numParams);
|
2023-03-23 17:39:00 +01:00
|
|
|
return (OutputLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
///////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testBackTrackLineGradientDescent() {
|
2019-06-06 15:21:15 +03:00
|
|
|
OptimizationAlgorithm optimizer = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
DataSetIterator irisIter = new IrisDataSetIterator(1, 1);
|
|
|
|
|
DataSet data = irisIter.next();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
|
|
|
|
|
network.init();
|
2020-01-04 13:45:07 +11:00
|
|
|
TrainingListener listener = new ScoreIterationListener(10);
|
2023-03-23 17:39:00 +01:00
|
|
|
network.addTrainingListeners(Collections.singletonList(listener));
|
2019-06-06 15:21:15 +03:00
|
|
|
double oldScore = network.score(data);
|
2022-09-20 15:40:53 +02:00
|
|
|
for( int i=0; i<100; i++ ) {
|
2019-06-06 15:21:15 +03:00
|
|
|
network.fit(data.getFeatures(), data.getLabels());
|
|
|
|
|
}
|
2023-03-23 17:39:00 +01:00
|
|
|
double score = network.getScore();
|
2019-06-06 15:21:15 +03:00
|
|
|
assertTrue(score < oldScore);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testBackTrackLineCG() {
|
2019-06-06 15:21:15 +03:00
|
|
|
OptimizationAlgorithm optimizer = OptimizationAlgorithm.CONJUGATE_GRADIENT;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
DataSet data = irisIter.next();
|
|
|
|
|
data.normalizeZeroMeanZeroUnitVariance();
|
|
|
|
|
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
|
|
|
|
|
network.init();
|
2020-01-04 13:45:07 +11:00
|
|
|
TrainingListener listener = new ScoreIterationListener(10);
|
2023-03-23 17:39:00 +01:00
|
|
|
network.addTrainingListeners(Collections.singletonList(listener));
|
2019-06-06 15:21:15 +03:00
|
|
|
double firstScore = network.score(data);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
for( int i=0; i<5; i++ ) {
|
2019-06-06 15:21:15 +03:00
|
|
|
network.fit(data.getFeatures(), data.getLabels());
|
|
|
|
|
}
|
2023-03-23 17:39:00 +01:00
|
|
|
double score = network.getScore();
|
2019-06-06 15:21:15 +03:00
|
|
|
assertTrue(score < firstScore);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testBackTrackLineLBFGS() {
|
2019-06-06 15:21:15 +03:00
|
|
|
OptimizationAlgorithm optimizer = OptimizationAlgorithm.LBFGS;
|
|
|
|
|
DataSet data = irisIter.next();
|
|
|
|
|
data.normalizeZeroMeanZeroUnitVariance();
|
|
|
|
|
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
|
|
|
|
|
network.init();
|
2020-01-04 13:45:07 +11:00
|
|
|
TrainingListener listener = new ScoreIterationListener(10);
|
2023-03-23 17:39:00 +01:00
|
|
|
network.addTrainingListeners(Collections.singletonList(listener));
|
2019-06-06 15:21:15 +03:00
|
|
|
double oldScore = network.score(data);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
for( int i=0; i<5; i++ ) {
|
2019-06-06 15:21:15 +03:00
|
|
|
network.fit(data.getFeatures(), data.getLabels());
|
|
|
|
|
}
|
2023-03-23 17:39:00 +01:00
|
|
|
double score = network.getScore();
|
2019-06-06 15:21:15 +03:00
|
|
|
assertTrue(score < oldScore);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
private static NeuralNetConfiguration getIrisMultiLayerConfig(Activation activationFunction, OptimizationAlgorithm optimizer) {
|
|
|
|
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().optimizationAlgo(optimizer)
|
2022-09-20 15:40:53 +02:00
|
|
|
.updater(new Adam(0.01)).seed(12345L).list()
|
2023-04-24 18:09:11 +02:00
|
|
|
.layer(0, DenseLayer.builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
|
2022-09-20 15:40:53 +02:00
|
|
|
.activation(activationFunction).build())
|
2023-04-24 18:09:11 +02:00
|
|
|
.layer(1, org.deeplearning4j.nn.conf.layers.OutputLayer.builder().lossFunction(
|
2022-09-20 15:40:53 +02:00
|
|
|
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
|
|
|
|
|
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
|
|
|
|
|
.build())
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
return conf;
|
|
|
|
|
}
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|