MKLDNNConvHelper: don't permute; pass in weight format arg (#412)
* MKLDNNConvHelper: don't permute; pass in weight format arg
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Fix flaky tests unrelated to main branch change
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Don't log every extracted file for LFW and other auto-downloaded image loaders using BaseImageLoader.downloadAndUntar
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
58b11bfecc
commit
9e7395667f
|
@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
return 120_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 45000L;
|
||||
return 120_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||
|
|
|
@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
|
|||
String fileName = file.toString();
|
||||
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|
||||
|| fileName.endsWith(".zip"))
|
||||
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
|
||||
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
|
||||
} catch (IOException e) {
|
||||
throw new IllegalStateException("Unable to fetch images", e);
|
||||
}
|
||||
|
|
|
@ -66,10 +66,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
|||
if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT)
|
||||
return null; //MKL-DNN only supports floating point dtype
|
||||
|
||||
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
|
||||
INDArray weightsPermute = weights.permute(2,3,1,0);
|
||||
INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0);
|
||||
|
||||
int hDim = 2;
|
||||
int wDim = 3;
|
||||
if(format == CNN2DFormat.NHWC){
|
||||
|
@ -89,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
|||
pad[0], pad[1],
|
||||
dilation[0], dilation[1],
|
||||
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
||||
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
|
||||
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
|
||||
1 //Weight format: 1 - [oC, iC, kH, kW]
|
||||
);
|
||||
};
|
||||
|
||||
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
||||
|
||||
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
|
||||
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
|
||||
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta};
|
||||
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView};
|
||||
contextBwd.purge();
|
||||
for( int i=0; i<inputsArr.length; i++ ){
|
||||
contextBwd.setInputArray(i, inputsArr[i]);
|
||||
|
@ -149,7 +146,8 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
|||
pad[0], pad[1],
|
||||
dilation[0], dilation[1],
|
||||
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
||||
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
|
||||
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
|
||||
1 //Weight format: 1 - [oC, iC, kH, kW]
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -157,9 +155,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
|||
long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth};
|
||||
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
|
||||
|
||||
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
|
||||
weights = weights.permute(2,3,1,0);
|
||||
|
||||
INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
|
||||
context.purge();
|
||||
for( int i=0; i<inputsArr.length; i++ ){
|
||||
|
|
|
@ -62,13 +62,11 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
|
|||
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
||||
.setFeatureExtractor(0).build();
|
||||
|
||||
File f = testDir.newFile("dl4jTestTransferStatsCollection.bin");
|
||||
f.delete();
|
||||
File dir = testDir.newFolder();
|
||||
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
|
||||
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
||||
|
||||
//Previously: failed on frozen layers
|
||||
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
||||
|
||||
f.deleteOnExit();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,16 +44,30 @@ public class ArchiveUtils {
|
|||
}
|
||||
|
||||
/**
|
||||
* Extracts files to the specified destination
|
||||
* Extracts all files from the archive to the specified destination.<br>
|
||||
* Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if
|
||||
* logging is not desired.<br>
|
||||
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
|
||||
* Format is interpreted from the filename
|
||||
*
|
||||
* @param file the file to extract to
|
||||
* @param dest the destination directory
|
||||
* @throws IOException
|
||||
* @param file the file to extract the files from
|
||||
* @param dest the destination directory. Will be created if it does not exist
|
||||
* @throws IOException If an error occurs accessing the files or extracting
|
||||
*/
|
||||
public static void unzipFileTo(String file, String dest) throws IOException {
|
||||
unzipFileTo(file, dest, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts all files from the archive to the specified destination, optionally logging the extracted file path.<br>
|
||||
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
|
||||
* Format is interpreted from the filename
|
||||
*
|
||||
* @param file the file to extract the files from
|
||||
* @param dest the destination directory. Will be created if it does not exist
|
||||
* @param logFiles If true: log the path of every extracted file; if false do not log
|
||||
* @throws IOException If an error occurs accessing the files or extracting
|
||||
*/
|
||||
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
|
||||
File target = new File(file);
|
||||
if (!target.exists())
|
||||
|
|
Loading…
Reference in New Issue