[WIP] Some optimizations in dl4j (#11)
* Minor optimization. * Reduce number of objects. * Extend arrays when limit reached * Test * Some fixes. * Small fix * Wrong condition fixed. * Fixes of reallocation. * Small fix. * Tests * Clean up * Test added. * Tests and some fixes. * Test * Test fixed. * Conflict fixed. * UX improvedmaster
parent
730442ae21
commit
d98784197a
|
@ -36,6 +36,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
|
@ -242,7 +243,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
|
||||
model.init();
|
||||
|
||||
model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
|
||||
//model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
|
||||
|
||||
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
|
||||
model.setListeners(listener);
|
||||
|
||||
model.fit(cifar);
|
||||
|
||||
|
@ -254,6 +258,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
|||
eval.eval(testDS.getLabels(), output);
|
||||
}
|
||||
System.out.println(eval.stats(true));
|
||||
listener.exportScores(System.out);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.io.FileOutputStream;
|
|||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
@ -37,7 +38,83 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
|
|||
|
||||
private int frequency;
|
||||
private int iterationCount = 0;
|
||||
private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
|
||||
//private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
|
||||
|
||||
public static class ScoreStat {
|
||||
public static final int BUCKET_LENGTH = 10000;
|
||||
|
||||
private int position = 0;
|
||||
private int bucketNumber = 1;
|
||||
private List<long[]> indexes;
|
||||
private List<double[]> scores;
|
||||
|
||||
public ScoreStat() {
|
||||
indexes = new ArrayList<>(1);
|
||||
indexes.add(new long[BUCKET_LENGTH]);
|
||||
scores = new ArrayList<>(1);
|
||||
scores.add(new double[BUCKET_LENGTH]);
|
||||
}
|
||||
|
||||
public List<long[]> getIndexes() {
|
||||
return indexes;
|
||||
}
|
||||
|
||||
public List<double[]> getScores() {
|
||||
return scores;
|
||||
}
|
||||
|
||||
public long[] getEffectiveIndexes() {
|
||||
return Arrays.copyOfRange(indexes.get(0), 0, position);
|
||||
}
|
||||
|
||||
public double[] getEffectiveScores() {
|
||||
return Arrays.copyOfRange(scores.get(0), 0, position);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
Originally scores array is initialized with BUCKET_LENGTH size.
|
||||
When data doesn't fit there - arrays size is increased for BUCKET_LENGTH,
|
||||
old data is copied and bucketNumber (counter of reallocations) being incremented.
|
||||
|
||||
If we got more score points than MAX_VALUE - they are put to another item of scores list.
|
||||
*/
|
||||
private void reallocateGuard() {
|
||||
if (position >= BUCKET_LENGTH * bucketNumber) {
|
||||
|
||||
long fullLength = (long)BUCKET_LENGTH * bucketNumber;
|
||||
|
||||
if (position == Integer.MAX_VALUE || fullLength >= Integer.MAX_VALUE) {
|
||||
position = 0;
|
||||
long[] newIndexes = new long[BUCKET_LENGTH];
|
||||
double[] newScores = new double[BUCKET_LENGTH];
|
||||
indexes.add(newIndexes);
|
||||
scores.add(newScores);
|
||||
}
|
||||
else {
|
||||
long[] newIndexes = new long[(int)fullLength + BUCKET_LENGTH];
|
||||
double[] newScores = new double[(int)fullLength + BUCKET_LENGTH];
|
||||
System.arraycopy(indexes.get(indexes.size()-1), 0, newIndexes, 0, (int)fullLength);
|
||||
System.arraycopy(scores.get(scores.size()-1), 0, newScores, 0, (int)fullLength);
|
||||
scores.remove(scores.size()-1);
|
||||
indexes.remove(indexes.size()-1);
|
||||
int lastIndex = scores.size() == 0 ? 0 : scores.size()-1;
|
||||
scores.add(lastIndex, newScores);
|
||||
indexes.add(lastIndex, newIndexes);
|
||||
}
|
||||
bucketNumber += 1;
|
||||
}
|
||||
}
|
||||
|
||||
public void addScore(long index, double score) {
|
||||
reallocateGuard();
|
||||
scores.get(scores.size() - 1)[position] = score;
|
||||
indexes.get(scores.size() - 1)[position] = index;
|
||||
position += 1;
|
||||
}
|
||||
}
|
||||
|
||||
ScoreStat scoreVsIter = new ScoreStat();
|
||||
|
||||
/**
|
||||
* Constructor for collecting scores with default saving frequency of 1
|
||||
|
@ -60,11 +137,12 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
|
|||
public void iterationDone(Model model, int iteration, int epoch) {
|
||||
if (++iterationCount % frequency == 0) {
|
||||
double score = model.score();
|
||||
scoreVsIter.add(new Pair<>(iterationCount, score));
|
||||
scoreVsIter.reallocateGuard();
|
||||
scoreVsIter.addScore(iteration, score);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Pair<Integer, Double>> getScoreVsIter() {
|
||||
public ScoreStat getScoreVsIter() {
|
||||
return scoreVsIter;
|
||||
}
|
||||
|
||||
|
@ -84,8 +162,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
|
|||
public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Iteration").append(delimiter).append("Score");
|
||||
for (Pair<Integer, Double> p : scoreVsIter) {
|
||||
sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond());
|
||||
int largeBuckets = scoreVsIter.indexes.size();
|
||||
for (int j = 0; j < largeBuckets; ++j) {
|
||||
long[] indexes = scoreVsIter.indexes.get(j);
|
||||
double[] scores = scoreVsIter.scores.get(j);
|
||||
|
||||
int effectiveLength = (j < largeBuckets -1) ? indexes.length : scoreVsIter.position;
|
||||
|
||||
for (int i = 0; i < effectiveLength; ++i) {
|
||||
sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]);
|
||||
}
|
||||
}
|
||||
outputStream.write(sb.toString().getBytes("UTF-8"));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
package org.deeplearning4j.optimize.listeners;
|
||||
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.List;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class ScoreStatTest {
|
||||
@Test
|
||||
public void testScoreStatSmall() {
|
||||
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
||||
for (int i = 0; i < CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH; ++i) {
|
||||
double score = (double)i;
|
||||
statTest.addScore(i, score);
|
||||
}
|
||||
|
||||
List<long[]> indexes = statTest.getIndexes();
|
||||
List<double[]> scores = statTest.getScores();
|
||||
|
||||
assertTrue(indexes.size() == 1);
|
||||
assertTrue(scores.size() == 1);
|
||||
|
||||
assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
|
||||
assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
|
||||
assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1);
|
||||
assertEquals(scores.get(0)[scores.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScoreStatAverage() {
|
||||
int dataSize = 1000000;
|
||||
long[] indexes = new long[dataSize];
|
||||
double[] scores = new double[dataSize];
|
||||
|
||||
for (int i = 0; i < dataSize; ++i) {
|
||||
indexes[i] = i;
|
||||
scores[i] = i;
|
||||
}
|
||||
|
||||
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
||||
for (int i = 0; i < dataSize; ++i) {
|
||||
statTest.addScore(indexes[i], scores[i]);
|
||||
}
|
||||
|
||||
long[] indexesStored = statTest.getIndexes().get(0);
|
||||
double[] scoresStored = statTest.getScores().get(0);
|
||||
|
||||
assertArrayEquals(indexes, indexesStored);
|
||||
assertArrayEquals(scores, scoresStored, 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScoresClean() {
|
||||
int dataSize = 10256; // expected to be placed in 2 buckets of 10k elements size
|
||||
long[] indexes = new long[dataSize];
|
||||
double[] scores = new double[dataSize];
|
||||
|
||||
for (int i = 0; i < dataSize; ++i) {
|
||||
indexes[i] = i;
|
||||
scores[i] = i;
|
||||
}
|
||||
|
||||
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
||||
for (int i = 0; i < dataSize; ++i) {
|
||||
statTest.addScore(indexes[i], scores[i]);
|
||||
}
|
||||
|
||||
long[] indexesEffective = statTest.getEffectiveIndexes();
|
||||
double[] scoresEffective = statTest.getEffectiveScores();
|
||||
|
||||
assertArrayEquals(indexes, indexesEffective);
|
||||
assertArrayEquals(scores, scoresEffective, 1e-4);
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testScoreStatBig() {
|
||||
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
||||
long bigLength = (long)Integer.MAX_VALUE + 5;
|
||||
for (long i = 0; i < bigLength; ++i) {
|
||||
double score = (double)i;
|
||||
statTest.addScore(i, score);
|
||||
}
|
||||
|
||||
List<long[]> indexes = statTest.getIndexes();
|
||||
List<double[]> scores = statTest.getScores();
|
||||
|
||||
assertTrue(indexes.size() == 2);
|
||||
assertTrue(scores.size() == 2);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
assertTrue(indexes.get(1)[i] == Integer.MAX_VALUE + i);
|
||||
assertTrue(scores.get(1)[i] == Integer.MAX_VALUE + i);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue