#8751 Arbiter grid search candidate generator fix [WIP] (#292)

* #8751 Arbiter grid search candidate generator fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Timeout

Signed-off-by: Alex Black <blacka101@gmail.com>
Branch: master
Author: Alex Black <blacka101@gmail.com>
Date:   2020-03-06 12:01:21 +11:00 (committed by GitHub)
Commit: 7494117e90
Parent: 19d5a8d49d
3 changed files with 65 additions and 7 deletions

GridSearchCandidateGenerator.java

@@ -196,6 +196,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
         // 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
         //Based on: Nd4j Shape.ind2sub
+        int countNon1 = 0;
+        for( int i : numValuesPerParam)
+            if(i > 1)
+                countNon1++;
+
         int denom = product;
         int num = candidateIdx;
         int[] index = new int[numValuesPerParam.length];
@@ -209,12 +214,11 @@ public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
         //Now: convert indexes to values in range [0,1]
         //min value -> 0
         //max value -> 1
-        double[] out = new double[numValuesPerParam.length];
-        for (int i = 0; i < out.length; i++) {
-            if (numValuesPerParam[i] <= 1)
-                out[i] = 0.0;
-            else {
-                out[i] = index[i] / ((double) (numValuesPerParam[i] - 1));
+        double[] out = new double[countNon1];
+        int outIdx = 0;
+        for (int i = 0; i < numValuesPerParam.length; i++) {
+            if (numValuesPerParam[i] > 1){
+                out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
             }
         }
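
Note on the fix: the generator maps a flat candidate index to one subscript per hyperparameter (Nd4j Shape.ind2sub style, first dimension varying fastest) and then rescales each subscript into [0,1]. Before this change the output array had one entry per parameter, with 0.0 filled in for parameters that take only a single value; after the change, single-value parameters are skipped entirely, so the vector length matches the number of parameters that actually vary. A minimal standalone sketch of the mapping, with hypothetical class and method names (this is not Arbiter's API, just an illustration of the same logic):

import java.util.Arrays;

public class GridIndexSketch {

    // Convert a flat candidate index to a [0,1] value per *varying* parameter.
    // First dimension varies fastest: 0 -> [0,0,0], 1 -> [1,0,0], 3 -> [0,1,0], ...
    static double[] indexToGridPoint(int candidateIdx, int[] numValuesPerParam) {
        int[] index = new int[numValuesPerParam.length];
        int num = candidateIdx;
        for (int i = 0; i < numValuesPerParam.length; i++) {
            index[i] = num % numValuesPerParam[i];   // subscript along dimension i
            num /= numValuesPerParam[i];
        }

        int countNon1 = 0;                           // parameters with more than one value
        for (int v : numValuesPerParam)
            if (v > 1)
                countNon1++;

        double[] out = new double[countNon1];        // fixed-value parameters are skipped
        int outIdx = 0;
        for (int i = 0; i < numValuesPerParam.length; i++) {
            if (numValuesPerParam[i] > 1)
                out[outIdx++] = index[i] / (double) (numValuesPerParam[i] - 1);
        }
        return out;
    }

    public static void main(String[] args) {
        // 3 values for param 0, one fixed param, 3 values for param 2 -> 9 candidates,
        // each mapped to a length-2 vector (only the two varying parameters appear).
        int[] grid = {3, 1, 3};
        for (int idx = 0; idx < 9; idx++) {
            System.out.println(idx + " -> " + Arrays.toString(indexToGridPoint(idx, grid)));
        }
    }
}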

TestMultiLayerSpace.java

@@ -21,6 +21,7 @@ import org.deeplearning4j.arbiter.DL4JConfiguration;
 import org.deeplearning4j.arbiter.MultiLayerSpace;
 import org.deeplearning4j.arbiter.TestUtils;
 import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
+import org.deeplearning4j.arbiter.conf.updater.NesterovsSpace;
 import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
 import org.deeplearning4j.arbiter.layers.*;
 import org.deeplearning4j.arbiter.optimize.api.Candidate;
@@ -80,6 +81,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
 import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.primitives.Pair;

 import java.io.File;
 import java.lang.reflect.Field;
@@ -767,4 +769,52 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
             assertEquals(expCandidates, count);
         }
     }
+
+    @Test
+    public void testGridCandidateGenerator(){
+        ParameterSpace<Integer> layerSizeParam = new DiscreteParameterSpace<>(32, 48, 64);
+        ParameterSpace<Double> learningRateParam = new DiscreteParameterSpace<>(0.005, 0.007, 0.01);
+
+        MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder()
+                .seed(12345)
+                .biasInit(1)
+                .l2(1e-4)
+                .updater(new NesterovsSpace(learningRateParam))
+                .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(layerSizeParam)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new DenseLayerSpace.Builder().nIn(layerSizeParam).nOut(layerSizeParam)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new OutputLayerSpace.Builder()
+                        .lossFunction(LossFunctions.LossFunction.MSE)
+                        .weightInit(WeightInit.XAVIER)
+                        .activation(Activation.SOFTMAX)
+                        .nIn(layerSizeParam).nOut(10).build())
+                .build();
+
+        CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(hyperParamaterSpace, 30, GridSearchCandidateGenerator.Mode.Sequential, null);
+//        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace);
+
+        Set<Pair<Double,Integer>> expCandidates = new HashSet<>();
+        for(Double d : new double[]{0.005, 0.007, 0.01}){
+            for(int i : new int[]{32, 48, 64}){
+                expCandidates.add(new Pair<>(d, i));
+            }
+        }
+
+        Set<Pair<Double,Integer>> actCandidates = new HashSet<>();
+        while(candidateGenerator.hasMoreCandidates()) {
+            Candidate<DL4JConfiguration> conf = candidateGenerator.getCandidate();
+            MultiLayerConfiguration mlc = conf.getValue().getMultiLayerConfiguration();
+            FeedForwardLayer ffl = ((FeedForwardLayer) mlc.getConf(0).getLayer());
+//            System.out.println(ffl.getIUpdater() + ", " + ffl.getNOut());
+            actCandidates.add(new Pair<>(ffl.getIUpdater().getLearningRate(0,0), (int)ffl.getNOut()));
+        }
+
+        assertEquals(expCandidates, actCandidates);
+    }
 }
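
Note on the test: the space has two varying hyperparameters (3 learning rates x 3 layer sizes), so Sequential mode should produce exactly 9 distinct candidates, each visited once; comparing HashSets makes the assertion order-independent. A quick sanity sketch of the enumeration (which parameter maps to the fastest-varying dimension depends on how MultiLayerSpace orders its leaf spaces; learning-rate-first below is only an assumption for illustration):

import java.util.ArrayList;
import java.util.List;

public class GridOrderSketch {
    public static void main(String[] args) {
        double[] lrs = {0.005, 0.007, 0.01};
        int[] sizes = {32, 48, 64};
        // Walk the flat candidate index the way a sequential 3x3 grid would,
        // first dimension fastest: 9 candidates, no repeats.
        List<String> candidates = new ArrayList<>();
        for (int idx = 0; idx < lrs.length * sizes.length; idx++) {
            int i = idx % lrs.length;   // learning-rate subscript (fastest)
            int j = idx / lrs.length;   // layer-size subscript
            candidates.add("(" + lrs[i] + ", " + sizes[j] + ")");
        }
        System.out.println(candidates.size() + " candidates: " + candidates);
    }
}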

ArbiterCLIRunnerTest.java

@@ -55,6 +55,10 @@ import static org.junit.Assert.assertEquals;
 @Slf4j
 public class ArbiterCLIRunnerTest extends BaseDL4JTest {
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000;
+    }

     @Test
     public void testCliRunner() throws Exception {
@@ -67,7 +71,7 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
                 .l2(new ContinuousParameterSpace(0.0001, 0.01))
                 .addLayer(new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2,10))
                         .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
-                        .build(),new IntegerParameterSpace(1,2),true)    //1-2 identical layers (except nIn)
+                        .build())
                 .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .numEpochs(3).build();
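
Note on the timeout: overriding getTimeoutMilliseconds() is the BaseDL4JTest hook for extending the per-test time limit, giving the CLI run 90 seconds on slow CI machines. For reference, the plain JUnit 4 equivalent (these tests use org.junit.Assert, i.e. JUnit 4) would be a Timeout rule; a sketch for a class not extending BaseDL4JTest:

import org.junit.Rule;
import org.junit.rules.Timeout;

public class StandaloneTimeoutExample {
    // Same 90-second per-test limit, expressed as a JUnit 4 rule.
    @Rule
    public Timeout timeout = Timeout.millis(90000);
}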