/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.deeplearning4j.optimize.solver;

import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.optimize.solvers.BackTrackLineSearch;
import org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.Collections;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Adam Gibson
*/
public class BackTrackLineSearchTest extends BaseDL4JTest {
2019-06-06 15:21:15 +03:00
private DataSetIterator irisIter;
private DataSet irisData;
2021-03-15 13:02:01 +09:00
@BeforeEach
public void before() {
2019-06-06 15:21:15 +03:00
if (irisIter == null) {
irisIter = new IrisDataSetIterator(5, 5);
}
if (irisData == null) {
irisData = irisIter.next();
irisData.normalizeZeroMeanZeroUnitVariance();
}
}
2019-06-06 15:21:15 +03:00
@Test
public void testSingleMinLineSearch() throws Exception {
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
int nParams = (int)layer.numParams();
2019-06-06 15:21:15 +03:00
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
2019-06-06 15:21:15 +03:00
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer());
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
2019-06-06 15:21:15 +03:00
assertEquals(1.0, step, 1e-3);
}
@Test
public void testSingleMaxLineSearch() throws Exception {
2019-06-06 15:21:15 +03:00
double score1, score2;
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
int nParams = (int)layer.numParams();
2019-06-06 15:21:15 +03:00
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.getScore();
BackTrackLineSearch lineSearch =
new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer());
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
2019-06-06 15:21:15 +03:00
assertEquals(1.0, step, 1e-3);
}
2019-06-06 15:21:15 +03:00
@Test
public void testMultMinLineSearch() throws Exception {
2019-06-06 15:21:15 +03:00
double score1, score2;
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100,
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
int nParams = (int)layer.numParams();
2019-06-06 15:21:15 +03:00
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.getScore();
2019-06-06 15:21:15 +03:00
INDArray origGradient = layer.gradient().gradient().dup();
2019-06-06 15:21:15 +03:00
NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction();
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
INDArray currParams = layer.getModelParams();
2019-06-06 15:21:15 +03:00
sf.step(currParams, origGradient, step);
layer.setParamsTable(currParams);
2019-06-06 15:21:15 +03:00
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score2 = layer.getScore();
assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2);
2019-06-06 15:21:15 +03:00
}
@Test
public void testMultMaxLineSearch() throws Exception {
2019-06-06 15:21:15 +03:00
double score1, score2;
2019-06-06 15:21:15 +03:00
irisData.normalizeZeroMeanZeroUnitVariance();
OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.MCXENT);
int nParams = (int)layer.numParams();
2019-06-06 15:21:15 +03:00
layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.getScore();
2019-06-06 15:21:15 +03:00
INDArray origGradient = layer.gradient().gradient().dup();
2019-06-06 15:21:15 +03:00
DefaultStepFunction sf = new DefaultStepFunction();
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(),
layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable());
INDArray currParams = layer.getModelParams();
2019-06-06 15:21:15 +03:00
sf.step(currParams, origGradient, step);
layer.setParamsTable(currParams);
2019-06-06 15:21:15 +03:00
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score2 = layer.getScore();
assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2);
2019-06-06 15:21:15 +03:00
}
private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunction, int maxIterations,
LossFunctions.LossFunction lossFunction) {
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder().seed(12345L).miniBatch(true)
.maxNumLineSearchIterations(maxIterations)
.layer(org.deeplearning4j.nn.conf.layers.OutputLayer.builder().lossFunction(lossFunction)
.nIn(4).nOut(3).activation(activationFunction)
.weightInit(WeightInit.XAVIER).build())
.build();
val numParams = conf.getFirstLayer().initializer().numParams(conf);
2019-06-06 15:21:15 +03:00
INDArray params = Nd4j.create(1, numParams);
return (OutputLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
2019-06-06 15:21:15 +03:00
}
///////////////////////////////////////////////////////////////////////////
2019-06-06 15:21:15 +03:00
@Test
public void testBackTrackLineGradientDescent() {
2019-06-06 15:21:15 +03:00
OptimizationAlgorithm optimizer = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;
2019-06-06 15:21:15 +03:00
DataSetIterator irisIter = new IrisDataSetIterator(1, 1);
DataSet data = irisIter.next();
2019-06-06 15:21:15 +03:00
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
network.init();
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test 
spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
TrainingListener listener = new ScoreIterationListener(10);
network.addTrainingListeners(Collections.singletonList(listener));
2019-06-06 15:21:15 +03:00
double oldScore = network.score(data);
for( int i=0; i<100; i++ ) {
2019-06-06 15:21:15 +03:00
network.fit(data.getFeatures(), data.getLabels());
}
double score = network.getScore();
2019-06-06 15:21:15 +03:00
assertTrue(score < oldScore);
}
@Test
public void testBackTrackLineCG() {
2019-06-06 15:21:15 +03:00
OptimizationAlgorithm optimizer = OptimizationAlgorithm.CONJUGATE_GRADIENT;
2019-06-06 15:21:15 +03:00
DataSet data = irisIter.next();
data.normalizeZeroMeanZeroUnitVariance();
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
network.init();
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test 
spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
TrainingListener listener = new ScoreIterationListener(10);
network.addTrainingListeners(Collections.singletonList(listener));
2019-06-06 15:21:15 +03:00
double firstScore = network.score(data);
for( int i=0; i<5; i++ ) {
2019-06-06 15:21:15 +03:00
network.fit(data.getFeatures(), data.getLabels());
}
double score = network.getScore();
2019-06-06 15:21:15 +03:00
assertTrue(score < firstScore);
2019-06-06 15:21:15 +03:00
}
@Test
public void testBackTrackLineLBFGS() {
2019-06-06 15:21:15 +03:00
OptimizationAlgorithm optimizer = OptimizationAlgorithm.LBFGS;
DataSet data = irisIter.next();
data.normalizeZeroMeanZeroUnitVariance();
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
network.init();
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test 
spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
TrainingListener listener = new ScoreIterationListener(10);
network.addTrainingListeners(Collections.singletonList(listener));
2019-06-06 15:21:15 +03:00
double oldScore = network.score(data);
for( int i=0; i<5; i++ ) {
2019-06-06 15:21:15 +03:00
network.fit(data.getFeatures(), data.getLabels());
}
double score = network.getScore();
2019-06-06 15:21:15 +03:00
assertTrue(score < oldScore);
2019-06-06 15:21:15 +03:00
}
private static NeuralNetConfiguration getIrisMultiLayerConfig(Activation activationFunction, OptimizationAlgorithm optimizer) {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().optimizationAlgo(optimizer)
.updater(new Adam(0.01)).seed(12345L).list()
.layer(0, DenseLayer.builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
.activation(activationFunction).build())
.layer(1, org.deeplearning4j.nn.conf.layers.OutputLayer.builder().lossFunction(
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.build();
2019-06-06 15:21:15 +03:00
return conf;
}
2019-06-06 15:21:15 +03:00
}