Small fixes (#355)
* #8787 DataVec test fix Signed-off-by: Alex Black <blacka101@gmail.com> * New nd4j test + fix bad datavec test Signed-off-by: Alex Black <blacka101@gmail.com> * #8745 Fix flaky arbiter test Signed-off-by: Alex Black <blacka101@gmail.com> * One more test Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
0a27e9f41d
commit
81ebfeead1
|
@ -96,6 +96,11 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
|||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLocalExecutionDataSources() throws Exception {
|
||||
|
||||
|
@ -204,7 +209,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
|||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
.terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
|
@ -251,7 +256,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
|||
.candidateGenerator(candidateGenerator)
|
||||
.dataProvider(new TestMdsDataProvider(1, 32))
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
.terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.scoreFunction(ScoreFunctions.testSetAccuracy())
|
||||
.build();
|
||||
|
|
|
@ -72,11 +72,11 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
|||
protected File fullDir;
|
||||
|
||||
protected boolean useSubset = false;
|
||||
InputSplit[] inputSplit;
|
||||
protected InputSplit[] inputSplit;
|
||||
|
||||
public static Map<String, String> lfwData = new HashMap<>();
|
||||
public static Map<String, String> lfwLabel = new HashMap<>();
|
||||
public static Map<String, String> lfwSubsetData = new HashMap<>();
|
||||
public Map<String, String> lfwData = new HashMap<>();
|
||||
public Map<String, String> lfwLabel = new HashMap<>();
|
||||
public Map<String, String> lfwSubsetData = new HashMap<>();
|
||||
|
||||
public LFWLoader() {
|
||||
this(false);
|
||||
|
|
|
@ -45,15 +45,23 @@ import static org.junit.Assert.assertTrue;
|
|||
*/
|
||||
public class LoaderTests {
|
||||
|
||||
private static void ensureDataAvailable(){
|
||||
//Ensure test resources available by initializing CifarLoader and relying on auto download
|
||||
boolean preProcessCifar = false;
|
||||
int numExamples = 10;
|
||||
int row = 28;
|
||||
int col = 28;
|
||||
int channels = 1;
|
||||
for( boolean train : new boolean[]{true, false}){
|
||||
CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar);
|
||||
loader.next(numExamples);
|
||||
}
|
||||
new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42)).next();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLfwReader() throws Exception {
|
||||
String subDir = "lfw-a/lfw";
|
||||
File path = new File(FilenameUtils.concat(System.getProperty("user.home"), subDir));
|
||||
FileSplit fileSplit = new FileSplit(path, LFWLoader.ALLOWED_FORMATS, new Random(42));
|
||||
BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(42), LFWLoader.LABEL_PATTERN, 1, 1, 1);
|
||||
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
|
||||
RecordReader rr = new ImageRecordReader(250, 250, 3, LFWLoader.LABEL_PATTERN);
|
||||
rr.initialize(inputSplit[0]);
|
||||
RecordReader rr = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
|
||||
List<String> exptedLabel = rr.getLabels();
|
||||
|
||||
RecordReader rr2 = new LFWLoader(new long[] {250, 250, 3}, true).getRecordReader(1, 1, 1, new Random(42));
|
||||
|
@ -63,6 +71,7 @@ public class LoaderTests {
|
|||
|
||||
@Test
|
||||
public void testCifarLoader() {
|
||||
ensureDataAvailable();
|
||||
File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin"));
|
||||
CifarLoader cifar = new CifarLoader(false, dir);
|
||||
assertTrue(dir.exists());
|
||||
|
@ -71,6 +80,7 @@ public class LoaderTests {
|
|||
|
||||
@Test
|
||||
public void testCifarInputStream() throws Exception {
|
||||
ensureDataAvailable();
|
||||
// check train
|
||||
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
|
||||
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
|
||||
|
|
|
@ -2122,4 +2122,26 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSeqMask(){
|
||||
INDArray arr = Nd4j.createFromArray(1,2,3);
|
||||
INDArray maxLen = Nd4j.scalar(4);
|
||||
|
||||
INDArray out = Nd4j.create(DataType.INT32, 3, 4);
|
||||
out.assign(Integer.MAX_VALUE);
|
||||
|
||||
Nd4j.exec(DynamicCustomOp.builder("sequence_mask")
|
||||
.addInputs(arr, maxLen)
|
||||
.addOutputs(out)
|
||||
.build()
|
||||
);
|
||||
|
||||
INDArray exp = Nd4j.createFromArray(new int[][]{
|
||||
{1, 0, 0, 0},
|
||||
{1, 1, 0, 0},
|
||||
{1, 1, 1, 0}});
|
||||
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import lombok.Data;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.math3.linear.LUDecomposition;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -2498,4 +2499,30 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
assertEquals(Nd4j.createFromArray(2, 2), out);
|
||||
}
|
||||
|
||||
|
||||
@Test @Ignore //AB 2020/04/01 - https://github.com/eclipse/deeplearning4j/issues/8592
|
||||
public void testReshapeZeros(){
|
||||
int[][] shapes = new int[][]{{2,0}, {10,0}, {10, 0}, {2,0,0,10}, {10, 0}, {0, 0, 10}, {0,2,10}, {1,2,0}};
|
||||
int[][] reshape = new int[][]{{2,-1}, {2,0,-1}, {5,2,-1}, {2,0,-1}, {-1, 2, 0}, {2, -1, 0}, {2, 0, 0, 0, -1}, {2,0,-1}};
|
||||
int[][] expected = new int[][]{{2,0}, {2,0,5}, {5,2,0}, {2,0,10}, {5,2,0}, {2,5,0}, {2,0,0,0,10}, {2,0,1}};
|
||||
|
||||
for( int i=0; i<shapes.length; i++ ){
|
||||
System.out.println(i);
|
||||
long[] orig = ArrayUtil.toLongArray(shapes[i]);
|
||||
int[] r = reshape[i];
|
||||
long[] exp = ArrayUtil.toLongArray(expected[i]);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable v = sd.placeHolder("orig", DataType.FLOAT, orig);
|
||||
SDVariable rs = v.reshape(r);
|
||||
SDVariable rs2 = v.reshape(sd.constant(Nd4j.createFromArray(r)));
|
||||
|
||||
INDArray out = rs.eval(Collections.singletonMap("orig", Nd4j.create(DataType.FLOAT, orig)));
|
||||
assertArrayEquals(exp, out.shape());
|
||||
|
||||
out = rs2.eval(Collections.singletonMap("orig", Nd4j.create(DataType.FLOAT, orig)));
|
||||
assertArrayEquals(exp, out.shape());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue