* #8751 Arbiter grid search candidate generator fix Signed-off-by: Alex Black <blacka101@gmail.com> * Small fix Signed-off-by: Alex Black <blacka101@gmail.com> * Timeout Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
19d5a8d49d
commit
7494117e90
|
@ -196,6 +196,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
||||||
//Based on: Nd4j Shape.ind2sub
|
//Based on: Nd4j Shape.ind2sub
|
||||||
|
|
||||||
|
int countNon1 = 0;
|
||||||
|
for( int i : numValuesPerParam)
|
||||||
|
if(i > 1)
|
||||||
|
countNon1++;
|
||||||
|
|
||||||
int denom = product;
|
int denom = product;
|
||||||
int num = candidateIdx;
|
int num = candidateIdx;
|
||||||
int[] index = new int[numValuesPerParam.length];
|
int[] index = new int[numValuesPerParam.length];
|
||||||
|
@ -209,12 +214,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
//Now: convert indexes to values in range [0,1]
|
//Now: convert indexes to values in range [0,1]
|
||||||
//min value -> 0
|
//min value -> 0
|
||||||
//max value -> 1
|
//max value -> 1
|
||||||
double[] out = new double[numValuesPerParam.length];
|
double[] out = new double[countNon1];
|
||||||
for (int i = 0; i < out.length; i++) {
|
int outIdx = 0;
|
||||||
if (numValuesPerParam[i] <= 1)
|
for (int i = 0; i < numValuesPerParam.length; i++) {
|
||||||
out[i] = 0.0;
|
if (numValuesPerParam[i] > 1){
|
||||||
else {
|
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
||||||
out[i] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.deeplearning4j.arbiter.DL4JConfiguration;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.TestUtils;
|
import org.deeplearning4j.arbiter.TestUtils;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||||
|
import org.deeplearning4j.arbiter.conf.updater.NesterovsSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.*;
|
import org.deeplearning4j.arbiter.layers.*;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||||
|
@ -80,6 +81,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
|
@ -767,4 +769,52 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
|
||||||
assertEquals(expCandidates, count);
|
assertEquals(expCandidates, count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGridCandidateGenerator(){
|
||||||
|
ParameterSpace<Integer> layerSizeParam = new DiscreteParameterSpace<>(32, 48, 64);
|
||||||
|
ParameterSpace<Double> learningRateParam = new DiscreteParameterSpace<>(0.005, 0.007, 0.01);
|
||||||
|
|
||||||
|
MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder()
|
||||||
|
.seed(12345)
|
||||||
|
.biasInit(1)
|
||||||
|
.l2(1e-4)
|
||||||
|
.updater(new NesterovsSpace(learningRateParam))
|
||||||
|
.addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(layerSizeParam)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.build())
|
||||||
|
.addLayer(new DenseLayerSpace.Builder().nIn(layerSizeParam).nOut(layerSizeParam)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.build())
|
||||||
|
.addLayer(new OutputLayerSpace.Builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nIn(layerSizeParam).nOut(10).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(hyperParamaterSpace, 30, GridSearchCandidateGenerator.Mode.Sequential, null);
|
||||||
|
// CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace);
|
||||||
|
|
||||||
|
Set<Pair<Double,Integer>> expCandidates = new HashSet<>();
|
||||||
|
for(Double d : new double[]{0.005, 0.007, 0.01}){
|
||||||
|
for(int i : new int[]{32, 48, 64}){
|
||||||
|
expCandidates.add(new Pair<>(d, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Set<Pair<Double,Integer>> actCandidates = new HashSet<>();
|
||||||
|
while(candidateGenerator.hasMoreCandidates()) {
|
||||||
|
Candidate<DL4JConfiguration> conf = candidateGenerator.getCandidate();
|
||||||
|
MultiLayerConfiguration mlc = conf.getValue().getMultiLayerConfiguration();
|
||||||
|
FeedForwardLayer ffl = ((FeedForwardLayer) mlc.getConf(0).getLayer());
|
||||||
|
// System.out.println(ffl.getIUpdater() + ", " + ffl.getNOut());
|
||||||
|
actCandidates.add(new Pair<>(ffl.getIUpdater().getLearningRate(0,0), (int)ffl.getNOut()));
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(expCandidates, actCandidates);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,6 +55,10 @@ import static org.junit.Assert.assertEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCliRunner() throws Exception {
|
public void testCliRunner() throws Exception {
|
||||||
|
@ -67,7 +71,7 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||||
.l2(new ContinuousParameterSpace(0.0001, 0.01))
|
.l2(new ContinuousParameterSpace(0.0001, 0.01))
|
||||||
.addLayer(new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2,10))
|
.addLayer(new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2,10))
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||||
.build(),new IntegerParameterSpace(1,2),true) //1-2 identical layers (except nIn)
|
.build())
|
||||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.numEpochs(3).build();
|
.numEpochs(3).build();
|
||||||
|
|
Loading…
Reference in New Issue