2022-09-20 15:40:53 +02:00

159 lines
7.1 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.util.TestDataProviderMnist;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
@Timeout(20)
public class TestErrors extends BaseDL4JTest {
@TempDir
public File temp;
@Test
public void testAllInvalidConfig() throws Exception {
//Invalid config - basically check that this actually terminates
File f = temp;
MultiLayerSpace mls = new MultiLayerSpace.Builder()
.addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0
.activation(Activation.TANH)
.build())
.addLayer(new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
runner.execute();
}
@Test
public void testAllInvalidDataConfigMismatch() throws Exception {
//Valid config - but mismatched with provided data
File f = temp;
MultiLayerSpace mls = new MultiLayerSpace.Builder()
.addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(10) //INVALID: nOut of 0
.activation(Activation.TANH)
.build())
.addLayer(new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
runner.execute();
}
@Test
public void testAllInvalidConfigCG() throws Exception {
//Invalid config - basically check that this actually terminates
File f = temp;
ComputationGraphSpace mls = new ComputationGraphSpace.Builder()
.addInputs("in")
.layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0
.activation(Activation.TANH)
.build(), "in")
.layer("1", new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0")
.setOutputs("1")
.build();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxCandidatesCondition(5))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
runner.execute();
}
@Test
public void testAllInvalidDataConfigMismatchCG() throws Exception {
//Valid config - but mismatched with provided data
File f = temp;
ComputationGraphSpace mls = new ComputationGraphSpace.Builder()
.addInputs("in")
.layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(10)
.activation(Activation.TANH).build(), "in")
.addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0")
.setOutputs("1")
.build();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
runner.execute();
}
}