MKLDNNConvHelper: don't permute; pass in weight format arg (#412)

* MKLDNNConvHelper: don't permute; pass in weight format arg

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix flaky tests unrelated to main branch change

Signed-off-by: Alex Black <blacka101@gmail.com>

* Don't log every extracted file for LFW and other auto-downloaded image loaders using BaseImageLoader.downloadAndUntar

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-26 01:04:31 +10:00 committed by GitHub
parent 58b11bfecc
commit 9e7395667f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 32 additions and 25 deletions

View File

@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
return 120_000L;
}
@Test
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();

View File

@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 45000L;
return 120_000L;
}
@Test
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));

View File

@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
String fileName = file.toString();
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|| fileName.endsWith(".zip"))
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
} catch (IOException e) {
throw new IllegalStateException("Unable to fetch images", e);
}

View File

@ -66,10 +66,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT)
return null; //MKL-DNN only supports floating point dtype
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
INDArray weightsPermute = weights.permute(2,3,1,0);
INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0);
int hDim = 2;
int wDim = 3;
if(format == CNN2DFormat.NHWC){
@ -89,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
pad[0], pad[1],
dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
1 //Weight format: 1 - [oC, iC, kH, kW]
);
};
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta};
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView};
contextBwd.purge();
for( int i=0; i<inputsArr.length; i++ ){
contextBwd.setInputArray(i, inputsArr[i]);
@ -149,7 +146,8 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
pad[0], pad[1],
dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
1 //Weight format: 1 - [oC, iC, kH, kW]
);
};
@ -157,9 +155,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth};
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
weights = weights.permute(2,3,1,0);
INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
context.purge();
for( int i=0; i<inputsArr.length; i++ ){

View File

@ -62,13 +62,11 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build();
File f = testDir.newFile("dl4jTestTransferStatsCollection.bin");
f.delete();
File dir = testDir.newFolder();
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
//Previously: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
f.deleteOnExit();
}
}

View File

@ -44,16 +44,30 @@ public class ArchiveUtils {
}
/**
* Extracts files to the specified destination
* Extracts all files from the archive to the specified destination.<br>
* Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if
* logging is not desired.<br>
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
* Format is interpreted from the filename
*
* @param file the file to extract to
* @param dest the destination directory
* @throws IOException
* @param file the file to extract the files from
* @param dest the destination directory. Will be created if it does not exist
* @throws IOException If an error occurs accessing the files or extracting
*/
public static void unzipFileTo(String file, String dest) throws IOException {
unzipFileTo(file, dest, true);
}
/**
* Extracts all files from the archive to the specified destination, optionally logging the extracted file path.<br>
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
* Format is interpreted from the filename
*
* @param file the file to extract the files from
* @param dest the destination directory. Will be created if it does not exist
* @param logFiles If true: log the path of every extracted file; if false do not log
* @throws IOException If an error occurs accessing the files or extracting
*/
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
File target = new File(file);
if (!target.exists())