[WIP] Some optimizations in dl4j (#11)

* Minor optimization.

* Reduce number of objects.

* Extend arrays when limit reached

* Test

* Some fixes.

* Small fix

* Wrong condition fixed.

* Fixes of reallocation.

* Small fix.

* Tests

* Clean up

* Test added.

* Tests and some fixes.

* Test

* Test fixed.

* Conflict fixed.

* UX improved
master
Alexander Stoyakin 2019-10-24 14:15:50 +03:00 committed by GitHub
parent 730442ae21
commit d98784197a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 195 additions and 6 deletions

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Ignore;
import org.junit.Test;
@ -242,7 +243,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
//model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
model.setListeners(listener);
model.fit(cifar);
@ -254,6 +258,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
eval.eval(testDS.getLabels(), output);
}
System.out.println(eval.stats(true));
listener.exportScores(System.out);
}

View File

@ -25,6 +25,7 @@ import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
@ -37,7 +38,83 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
private int frequency;
private int iterationCount = 0;
private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
//private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
public static class ScoreStat {
public static final int BUCKET_LENGTH = 10000;
private int position = 0;
private int bucketNumber = 1;
private List<long[]> indexes;
private List<double[]> scores;
public ScoreStat() {
indexes = new ArrayList<>(1);
indexes.add(new long[BUCKET_LENGTH]);
scores = new ArrayList<>(1);
scores.add(new double[BUCKET_LENGTH]);
}
public List<long[]> getIndexes() {
return indexes;
}
public List<double[]> getScores() {
return scores;
}
public long[] getEffectiveIndexes() {
return Arrays.copyOfRange(indexes.get(0), 0, position);
}
public double[] getEffectiveScores() {
return Arrays.copyOfRange(scores.get(0), 0, position);
}
/*
Originally scores array is initialized with BUCKET_LENGTH size.
When data doesn't fit there - arrays size is increased for BUCKET_LENGTH,
old data is copied and bucketNumber (counter of reallocations) being incremented.
If we got more score points than MAX_VALUE - they are put to another item of scores list.
*/
private void reallocateGuard() {
if (position >= BUCKET_LENGTH * bucketNumber) {
long fullLength = (long)BUCKET_LENGTH * bucketNumber;
if (position == Integer.MAX_VALUE || fullLength >= Integer.MAX_VALUE) {
position = 0;
long[] newIndexes = new long[BUCKET_LENGTH];
double[] newScores = new double[BUCKET_LENGTH];
indexes.add(newIndexes);
scores.add(newScores);
}
else {
long[] newIndexes = new long[(int)fullLength + BUCKET_LENGTH];
double[] newScores = new double[(int)fullLength + BUCKET_LENGTH];
System.arraycopy(indexes.get(indexes.size()-1), 0, newIndexes, 0, (int)fullLength);
System.arraycopy(scores.get(scores.size()-1), 0, newScores, 0, (int)fullLength);
scores.remove(scores.size()-1);
indexes.remove(indexes.size()-1);
int lastIndex = scores.size() == 0 ? 0 : scores.size()-1;
scores.add(lastIndex, newScores);
indexes.add(lastIndex, newIndexes);
}
bucketNumber += 1;
}
}
public void addScore(long index, double score) {
reallocateGuard();
scores.get(scores.size() - 1)[position] = score;
indexes.get(scores.size() - 1)[position] = index;
position += 1;
}
}
ScoreStat scoreVsIter = new ScoreStat();
/**
* Constructor for collecting scores with default saving frequency of 1
@ -60,11 +137,12 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
public void iterationDone(Model model, int iteration, int epoch) {
if (++iterationCount % frequency == 0) {
double score = model.score();
scoreVsIter.add(new Pair<>(iterationCount, score));
scoreVsIter.reallocateGuard();
scoreVsIter.addScore(iteration, score);
}
}
public List<Pair<Integer, Double>> getScoreVsIter() {
public ScoreStat getScoreVsIter() {
return scoreVsIter;
}
@ -84,8 +162,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
StringBuilder sb = new StringBuilder();
sb.append("Iteration").append(delimiter).append("Score");
for (Pair<Integer, Double> p : scoreVsIter) {
sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond());
int largeBuckets = scoreVsIter.indexes.size();
for (int j = 0; j < largeBuckets; ++j) {
long[] indexes = scoreVsIter.indexes.get(j);
double[] scores = scoreVsIter.scores.get(j);
int effectiveLength = (j < largeBuckets -1) ? indexes.length : scoreVsIter.position;
for (int i = 0; i < effectiveLength; ++i) {
sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]);
}
}
outputStream.write(sb.toString().getBytes("UTF-8"));
}

View File

@ -0,0 +1,98 @@
package org.deeplearning4j.optimize.listeners;
import org.junit.Ignore;
import org.junit.Test;
import java.util.List;
import static org.junit.Assert.*;
public class ScoreStatTest {
@Test
public void testScoreStatSmall() {
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
for (int i = 0; i < CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH; ++i) {
double score = (double)i;
statTest.addScore(i, score);
}
List<long[]> indexes = statTest.getIndexes();
List<double[]> scores = statTest.getScores();
assertTrue(indexes.size() == 1);
assertTrue(scores.size() == 1);
assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1);
assertEquals(scores.get(0)[scores.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4);
}
@Test
public void testScoreStatAverage() {
int dataSize = 1000000;
long[] indexes = new long[dataSize];
double[] scores = new double[dataSize];
for (int i = 0; i < dataSize; ++i) {
indexes[i] = i;
scores[i] = i;
}
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
for (int i = 0; i < dataSize; ++i) {
statTest.addScore(indexes[i], scores[i]);
}
long[] indexesStored = statTest.getIndexes().get(0);
double[] scoresStored = statTest.getScores().get(0);
assertArrayEquals(indexes, indexesStored);
assertArrayEquals(scores, scoresStored, 1e-4);
}
@Test
public void testScoresClean() {
int dataSize = 10256; // expected to be placed in 2 buckets of 10k elements size
long[] indexes = new long[dataSize];
double[] scores = new double[dataSize];
for (int i = 0; i < dataSize; ++i) {
indexes[i] = i;
scores[i] = i;
}
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
for (int i = 0; i < dataSize; ++i) {
statTest.addScore(indexes[i], scores[i]);
}
long[] indexesEffective = statTest.getEffectiveIndexes();
double[] scoresEffective = statTest.getEffectiveScores();
assertArrayEquals(indexes, indexesEffective);
assertArrayEquals(scores, scoresEffective, 1e-4);
}
@Ignore
@Test
public void testScoreStatBig() {
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
long bigLength = (long)Integer.MAX_VALUE + 5;
for (long i = 0; i < bigLength; ++i) {
double score = (double)i;
statTest.addScore(i, score);
}
List<long[]> indexes = statTest.getIndexes();
List<double[]> scores = statTest.getScores();
assertTrue(indexes.size() == 2);
assertTrue(scores.size() == 2);
for (int i = 0; i < 5; ++i) {
assertTrue(indexes.get(1)[i] == Integer.MAX_VALUE + i);
assertTrue(scores.get(1)[i] == Integer.MAX_VALUE + i);
}
}
}