/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.arbiter.computationgraph;

import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.util.List;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
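
/**
 * Tests for {@link ComputationGraphSpace}: building hyperparameter search spaces over
 * {@link ComputationGraphConfiguration}s and sampling candidate configurations from them.
 */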
public class TestComputationGraphSpace {
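
    /**
     * A space with all values fixed (no hyperparameters) has zero degrees of freedom
     * and must produce exactly the equivalent hand-built configuration.
     */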
@Test
public void testBasic() {
        ComputationGraphConfiguration expected = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.005))
                .seed(12345)
                .graphBuilder()
                .addInputs("in")
                .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
                .addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).build(), "0")
                .addLayer("2", new OutputLayer.Builder().lossFunction(LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(10).nOut(5)
                        .build(), "1")
                .setOutputs("2")
                .build();
        ComputationGraphSpace cgs = new ComputationGraphSpace.Builder()
                .updater(new Sgd(0.005))
                .seed(12345)
                .addInputs("in")
                .addLayer("0", new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), "in")
                .addLayer("1", new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), "0")
                .addLayer("2", new OutputLayerSpace.Builder().lossFunction(LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(10).nOut(5)
                        .build(), "1")
                .setOutputs("2")
                .setInputTypes(InputType.feedForward(10))
                .build();
        // No hyperparameters to optimize -> zero-length parameter vector
        int nParams = cgs.numParameters();
        assertEquals(0, nParams);

        // With no free parameters, the single candidate must equal the fixed configuration
        ComputationGraphConfiguration conf = cgs.getValue(new double[0]).getConfiguration();
        assertEquals(expected, conf);
}
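
    /**
     * A space with three hyperparameters (learning rate, l2, and the hidden layer's
     * activation function): every randomly sampled candidate must stay inside the
     * declared ranges, and both discrete activation choices should occur.
     */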
@Test
public void testBasic2() {
        ComputationGraphSpace mls = new ComputationGraphSpace.Builder()
                .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
                .l2(new ContinuousParameterSpace(0.2, 0.5))
                .addInputs("in")
                .addLayer("0", new DenseLayerSpace.Builder().nIn(10).nOut(10)
                        .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
                        .build(), "in")
                .addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(10)
                        .activation(Activation.SOFTMAX)
                        .build(), "0")
                .setOutputs("1")
                .setInputTypes(InputType.feedForward(10))
                .build();
int nParams = mls.numParameters();
assertEquals(3, nParams);
        // Assign each leaf ParameterSpace object its position(s) in the flattened
        // parameter vector (normally done by the candidate generator)
        List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(mls.collectLeaves());
        int c = 0;
for (ParameterSpace ps : noDuplicatesList) {
int np = ps.numParameters();
if (np == 1) {
ps.setIndices(c++);
} else {
                // A multi-parameter space gets np consecutive positions in the vector
                int[] values = new int[np];
                for (int j = 0; j < np; j++) {
                    values[j] = c++;
                }
                ps.setIndices(values);
            }
}
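
        // Sample 50 random points from the space; every sampled configuration must
        // respect the declared hyperparameter ranges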
int reluCount = 0;
int tanhCount = 0;
Random r = new Random(12345);
for (int i = 0; i < 50; i++) {
            double[] rvs = new double[nParams];
            for (int j = 0; j < rvs.length; j++) {
                rvs[j] = r.nextDouble();
            }
ComputationGraphConfiguration conf = mls.getValue(rvs).getConfiguration();
int nLayers = conf.getVertexInputs().size();
assertEquals(2, nLayers);
for (int j = 0; j < nLayers; j++) {
NeuralNetConfiguration layerConf =
((LayerVertex) conf.getVertices().get(String.valueOf(j))).getLayerConf();
double lr = ((Sgd)((BaseLayer) layerConf.getLayer()).getIUpdater()).getLearningRate();
assertTrue(lr >= 0.0001 && lr <= 0.1);
double l2 = TestUtils.getL2(((BaseLayer) layerConf.getLayer()));
assertTrue(l2 >= 0.2 && l2 <= 0.5);
if (j == nLayers - 1) { //Output layer
assertEquals(Activation.SOFTMAX.getActivationFunction(),
((BaseLayer) layerConf.getLayer()).getActivationFn());
} else {
                    IActivation actFn = ((BaseLayer) layerConf.getLayer()).getActivationFn();
                    assertTrue(Activation.RELU.getActivationFunction().equals(actFn)
                            || Activation.TANH.getActivationFunction().equals(actFn));
                    if (Activation.RELU.getActivationFunction().equals(actFn)) {
                        reluCount++;
                    } else {
                        tanhCount++;
                    }
}
}
}
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
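
        // Over 50 random samples, both discrete activation values should appear at least once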
assertTrue(reluCount > 0);
assertTrue(tanhCount > 0);
}
}