159 lines
7.1 KiB
Java
159 lines
7.1 KiB
Java
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
|
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
|
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
|
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
|
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
|
|
import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction;
|
|
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
|
|
import org.deeplearning4j.arbiter.util.TestDataProviderMnist;
|
|
import org.junit.jupiter.api.Test;
|
|
import org.junit.jupiter.api.Timeout;
|
|
import org.junit.jupiter.api.io.TempDir;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
import java.io.File;
|
|
|
|
@Timeout(20)
|
|
public class TestErrors extends BaseDL4JTest {
|
|
|
|
@TempDir
|
|
public File temp;
|
|
|
|
@Test
|
|
public void testAllInvalidConfig() throws Exception {
|
|
//Invalid config - basically check that this actually terminates
|
|
|
|
File f = temp;
|
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
|
.addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0
|
|
.activation(Activation.TANH)
|
|
.build())
|
|
.addLayer(new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
.build();
|
|
|
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
|
|
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
|
.terminationConditions(
|
|
new MaxCandidatesCondition(5))
|
|
.build();
|
|
|
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
|
|
runner.execute();
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testAllInvalidDataConfigMismatch() throws Exception {
|
|
//Valid config - but mismatched with provided data
|
|
|
|
File f = temp;
|
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
|
.addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(10) //INVALID: nOut of 0
|
|
.activation(Activation.TANH)
|
|
.build())
|
|
.addLayer(new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
.build();
|
|
|
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
|
|
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
|
.terminationConditions(
|
|
new MaxCandidatesCondition(5))
|
|
.build();
|
|
|
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
|
|
runner.execute();
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testAllInvalidConfigCG() throws Exception {
|
|
//Invalid config - basically check that this actually terminates
|
|
|
|
File f = temp;
|
|
ComputationGraphSpace mls = new ComputationGraphSpace.Builder()
|
|
.addInputs("in")
|
|
.layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0
|
|
.activation(Activation.TANH)
|
|
.build(), "in")
|
|
.layer("1", new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0")
|
|
.setOutputs("1")
|
|
.build();
|
|
|
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
|
|
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
|
.terminationConditions(new MaxCandidatesCondition(5))
|
|
.build();
|
|
|
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration);
|
|
runner.execute();
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testAllInvalidDataConfigMismatchCG() throws Exception {
|
|
//Valid config - but mismatched with provided data
|
|
|
|
File f = temp;
|
|
ComputationGraphSpace mls = new ComputationGraphSpace.Builder()
|
|
.addInputs("in")
|
|
.layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(10)
|
|
.activation(Activation.TANH).build(), "in")
|
|
.addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0")
|
|
.setOutputs("1")
|
|
.build();
|
|
|
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
|
|
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
|
.terminationConditions(
|
|
new MaxCandidatesCondition(5))
|
|
.build();
|
|
|
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
|
runner.execute();
|
|
}
|
|
|
|
}
|