Small fixes (#355)

* #8787 DataVec test fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* New nd4j test + fix bad datavec test

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8745 Fix flaky arbiter test

Signed-off-by: Alex Black <blacka101@gmail.com>

* One more test

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-01 15:11:39 +11:00 committed by GitHub
parent 0a27e9f41d
commit 81ebfeead1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 13 deletions

View File

@ -96,6 +96,11 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testLocalExecutionDataSources() throws Exception {
@ -204,7 +209,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
.terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
@ -251,7 +256,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.candidateGenerator(candidateGenerator)
.dataProvider(new TestMdsDataProvider(1, 32))
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
.terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.scoreFunction(ScoreFunctions.testSetAccuracy())
.build();

View File

@ -72,11 +72,11 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
protected File fullDir;
protected boolean useSubset = false;
InputSplit[] inputSplit;
protected InputSplit[] inputSplit;
public static Map<String, String> lfwData = new HashMap<>();
public static Map<String, String> lfwLabel = new HashMap<>();
public static Map<String, String> lfwSubsetData = new HashMap<>();
public Map<String, String> lfwData = new HashMap<>();
public Map<String, String> lfwLabel = new HashMap<>();
public Map<String, String> lfwSubsetData = new HashMap<>();
public LFWLoader() {
this(false);

View File

@ -45,15 +45,23 @@ import static org.junit.Assert.assertTrue;
*/
public class LoaderTests {
private static void ensureDataAvailable(){
//Ensure test resources available by initializing CifarLoader and relying on auto download
boolean preProcessCifar = false;
int numExamples = 10;
int row = 28;
int col = 28;
int channels = 1;
for( boolean train : new boolean[]{true, false}){
CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar);
loader.next(numExamples);
}
new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)).next();
}
@Test
public void testLfwReader() throws Exception {
String subDir = "lfw-a/lfw";
File path = new File(FilenameUtils.concat(System.getProperty("user.home"), subDir));
FileSplit fileSplit = new FileSplit(path, LFWLoader.ALLOWED_FORMATS, new Random(42));
BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1);
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
RecordReader rr = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN);
rr.initialize(inputSplit[0]);
RecordReader rr = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
List<String> exptedLabel = rr.getLabels();
RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
@ -63,6 +71,7 @@ public class LoaderTests {
@Test
public void testCifarLoader() {
ensureDataAvailable();
File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin"));
CifarLoader cifar = new CifarLoader(false, dir);
assertTrue(dir.exists());
@ -71,6 +80,7 @@ public class LoaderTests {
@Test
public void testCifarInputStream() throws Exception {
ensureDataAvailable();
// check train
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);

View File

@ -2122,4 +2122,26 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err);
}
@Test
public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4);
INDArray out = Nd4j.create(DataType.INT32, 3, 4);
out.assign(Integer.MAX_VALUE);
Nd4j.exec(DynamicCustomOp.builder("sequence_mask")
.addInputs(arr, maxLen)
.addOutputs(out)
.build()
);
INDArray exp = Nd4j.createFromArray(new int[][]{
{1, 0, 0, 0},
{1, 1, 0, 0},
{1, 1, 1, 0}});
assertEquals(exp, out);
}
}

View File

@ -22,6 +22,7 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
@ -2498,4 +2499,30 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(Nd4j.createFromArray(2, 2), out);
}
@Test @Ignore //AB 2020/04/01 - https://github.com/eclipse/deeplearning4j/issues/8592
public void testReshapeZeros(){
int[][] shapes = new int[][]{{2,0}, {10,0}, {10, 0}, {2,0,0,10}, {10, 0}, {0, 0, 10}, {0,2,10}, {1,2,0}};
int[][] reshape = new int[][]{{2,-1}, {2,0,-1}, {5,2,-1}, {2,0,-1}, {-1, 2, 0}, {2, -1, 0}, {2, 0, 0, 0, -1}, {2,0,-1}};
int[][] expected = new int[][]{{2,0}, {2,0,5}, {5,2,0}, {2,0,10}, {5,2,0}, {2,5,0}, {2,0,0,0,10}, {2,0,1}};
for( int i=0; i<shapes.length; i++ ){
System.out.println(i);
long[] orig = ArrayUtil.toLongArray(shapes[i]);
int[] r = reshape[i];
long[] exp = ArrayUtil.toLongArray(expected[i]);
SameDiff sd = SameDiff.create();
SDVariable v = sd.placeHolder("orig", DataType.FLOAT, orig);
SDVariable rs = v.reshape(r);
SDVariable rs2 = v.reshape(sd.constant(Nd4j.createFromArray(r)));
INDArray out = rs.eval(Collections.singletonMap("orig", Nd4j.create(DataType.FLOAT, orig)));
assertArrayEquals(exp, out.shape());
out = rs2.eval(Collections.singletonMap("orig", Nd4j.create(DataType.FLOAT, orig)));
assertArrayEquals(exp, out.shape());
}
}
}