MKLDNNConvHelper: don't permute; pass in weight format arg (#412)

* MKLDNNConvHelper: don't permute; pass in weight format arg

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix flaky tests unrelated to main branch change

Signed-off-by: Alex Black <blacka101@gmail.com>

* Don't log every extracted file for LFW and other auto-downloaded image loaders using BaseImageLoader.downloadAndUntar

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-26 01:04:31 +10:00 committed by GitHub
parent 58b11bfecc
commit 9e7395667f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 32 additions and 25 deletions

View File

@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; return 120_000L;
} }
@Test @Test
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.dataSource(ds, dsP) .dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave)) .modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction()) .scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3)) new MaxCandidatesCondition(3))
.build(); .build();

View File

@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 45000L; return 120_000L;
} }
@Test @Test
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
.dataSource(ds, dsP) .dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave)) .modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction()) .scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10)) new MaxCandidatesCondition(3))
.build(); .build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator())); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));

View File

@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
String fileName = file.toString(); String fileName = file.toString();
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz") if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|| fileName.endsWith(".zip")) || fileName.endsWith(".zip"))
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath()); ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
} catch (IOException e) { } catch (IOException e) {
throw new IllegalStateException("Unable to fetch images", e); throw new IllegalStateException("Unable to fetch images", e);
} }

View File

@ -66,10 +66,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT) if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT)
return null; //MKL-DNN only supports floating point dtype return null; //MKL-DNN only supports floating point dtype
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
INDArray weightsPermute = weights.permute(2,3,1,0);
INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0);
int hDim = 2; int hDim = 2;
int wDim = 3; int wDim = 3;
if(format == CNN2DFormat.NHWC){ if(format == CNN2DFormat.NHWC){
@ -89,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
pad[0], pad[1], pad[0], pad[1],
dilation[0], dilation[1], dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
1 //Weight format: 1 - [oC, iC, kH, kW]
); );
}; };
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta}; INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta};
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView}; INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView};
contextBwd.purge(); contextBwd.purge();
for( int i=0; i<inputsArr.length; i++ ){ for( int i=0; i<inputsArr.length; i++ ){
contextBwd.setInputArray(i, inputsArr[i]); contextBwd.setInputArray(i, inputsArr[i]);
@ -149,7 +146,8 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
pad[0], pad[1], pad[0], pad[1],
dilation[0], dilation[1], dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
1 //Weight format: 1 - [oC, iC, kH, kW]
); );
}; };
@ -157,9 +155,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth}; long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth};
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape); INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
weights = weights.permute(2,3,1,0);
INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias}; INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
context.purge(); context.purge();
for( int i=0; i<inputsArr.length; i++ ){ for( int i=0; i<inputsArr.length; i++ ){

View File

@ -62,13 +62,11 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build()) new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build(); .setFeatureExtractor(0).build();
File f = testDir.newFile("dl4jTestTransferStatsCollection.bin"); File dir = testDir.newFolder();
f.delete(); File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
net2.setListeners(new StatsListener(new FileStatsStorage(f))); net2.setListeners(new StatsListener(new FileStatsStorage(f)));
//Previosuly: failed on frozen layers //Previosuly: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10))); net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
f.deleteOnExit();
} }
} }

View File

@ -44,16 +44,30 @@ public class ArchiveUtils {
} }
/** /**
* Extracts files to the specified destination * Extracts all files from the archive to the specified destination.<br>
* Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if
* logging is not desired.<br>
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
* Format is interpreted from the filename
* *
* @param file the file to extract to * @param file the file to extract the files from
* @param dest the destination directory * @param dest the destination directory. Will be created if it does not exist
* @throws IOException * @throws IOException If an error occurs accessing the files or extracting
*/ */
public static void unzipFileTo(String file, String dest) throws IOException { public static void unzipFileTo(String file, String dest) throws IOException {
unzipFileTo(file, dest, true); unzipFileTo(file, dest, true);
} }
/**
* Extracts all files from the archive to the specified destination, optionally logging the extracted file path.<br>
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
* Format is interpreted from the filename
*
* @param file the file to extract the files from
* @param dest the destination directory. Will be created if it does not exist
* @param logFiles If true: log the path of every extracted file; if false do not log
* @throws IOException If an error occurs accessing the files or extracting
*/
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException { public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
File target = new File(file); File target = new File(file);
if (!target.exists()) if (!target.exists())