MKLDNNConvHelper: don't permute; pass in weight format arg (#412)
* MKLDNNConvHelper: don't permute; pass in weight format arg
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Fix flaky tests unrelated to main branch change
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Don't log every extracted file for LFW and other auto-downloaded image loaders using BaseImageLoader.downloadAndUntar
  Signed-off-by: Alex Black <blacka101@gmail.com>
master
parent
58b11bfecc
commit
9e7395667f
|
@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 120_000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(3))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 45000L;
|
return 120_000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||||
|
|
|
@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
|
||||||
String fileName = file.toString();
|
String fileName = file.toString();
|
||||||
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|
||||||
|| fileName.endsWith(".zip"))
|
|| fileName.endsWith(".zip"))
|
||||||
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new IllegalStateException("Unable to fetch images", e);
|
throw new IllegalStateException("Unable to fetch images", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,10 +66,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
||||||
if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT)
|
if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT)
|
||||||
return null; //MKL-DNN only supports floating point dtype
|
return null; //MKL-DNN only supports floating point dtype
|
||||||
|
|
||||||
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
|
|
||||||
INDArray weightsPermute = weights.permute(2,3,1,0);
|
|
||||||
INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0);
|
|
||||||
|
|
||||||
int hDim = 2;
|
int hDim = 2;
|
||||||
int wDim = 3;
|
int wDim = 3;
|
||||||
if(format == CNN2DFormat.NHWC){
|
if(format == CNN2DFormat.NHWC){
|
||||||
|
@ -89,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
||||||
pad[0], pad[1],
|
pad[0], pad[1],
|
||||||
dilation[0], dilation[1],
|
dilation[0], dilation[1],
|
||||||
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
||||||
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
|
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
|
||||||
|
1 //Weight format: 1 - [oC, iC, kH, kW]
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
||||||
|
|
||||||
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
|
INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta};
|
||||||
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
|
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView};
|
||||||
contextBwd.purge();
|
contextBwd.purge();
|
||||||
for( int i=0; i<inputsArr.length; i++ ){
|
for( int i=0; i<inputsArr.length; i++ ){
|
||||||
contextBwd.setInputArray(i, inputsArr[i]);
|
contextBwd.setInputArray(i, inputsArr[i]);
|
||||||
|
@ -149,7 +146,8 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
||||||
pad[0], pad[1],
|
pad[0], pad[1],
|
||||||
dilation[0], dilation[1],
|
dilation[0], dilation[1],
|
||||||
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
|
||||||
format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC
|
format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC
|
||||||
|
1 //Weight format: 1 - [oC, iC, kH, kW]
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -157,9 +155,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
|
||||||
long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth};
|
long[] outShape = (format == CNN2DFormat.NCHW) ? new long[]{input.size(0), outDepth, outSize[0], outSize[1]} : new long[]{input.size(0), outSize[0], outSize[1], outDepth};
|
||||||
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
|
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
|
||||||
|
|
||||||
//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
|
|
||||||
weights = weights.permute(2,3,1,0);
|
|
||||||
|
|
||||||
INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
|
INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
|
||||||
context.purge();
|
context.purge();
|
||||||
for( int i=0; i<inputsArr.length; i++ ){
|
for( int i=0; i<inputsArr.length; i++ ){
|
||||||
|
|
|
@ -62,13 +62,11 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
|
||||||
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
||||||
.setFeatureExtractor(0).build();
|
.setFeatureExtractor(0).build();
|
||||||
|
|
||||||
File f = testDir.newFile("dl4jTestTransferStatsCollection.bin");
|
File dir = testDir.newFolder();
|
||||||
f.delete();
|
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
|
||||||
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
||||||
|
|
||||||
//Previosuly: failed on frozen layers
|
//Previosuly: failed on frozen layers
|
||||||
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
||||||
|
|
||||||
f.deleteOnExit();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,16 +44,30 @@ public class ArchiveUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts files to the specified destination
|
* Extracts all files from the archive to the specified destination.<br>
|
||||||
|
* Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if
|
||||||
|
* logging is not desired.<br>
|
||||||
|
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
|
||||||
|
* Format is interpreted from the filename
|
||||||
*
|
*
|
||||||
* @param file the file to extract to
|
* @param file the file to extract the files from
|
||||||
* @param dest the destination directory
|
* @param dest the destination directory. Will be created if it does not exist
|
||||||
* @throws IOException
|
* @throws IOException If an error occurs accessing the files or extracting
|
||||||
*/
|
*/
|
||||||
public static void unzipFileTo(String file, String dest) throws IOException {
|
public static void unzipFileTo(String file, String dest) throws IOException {
|
||||||
unzipFileTo(file, dest, true);
|
unzipFileTo(file, dest, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts all files from the archive to the specified destination, optionally logging the extracted file path.<br>
|
||||||
|
* Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats.
|
||||||
|
* Format is interpreted from the filename
|
||||||
|
*
|
||||||
|
* @param file the file to extract the files from
|
||||||
|
* @param dest the destination directory. Will be created if it does not exist
|
||||||
|
* @param logFiles If true: log the path of every extracted file; if false do not log
|
||||||
|
* @throws IOException If an error occurs accessing the files or extracting
|
||||||
|
*/
|
||||||
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
|
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
|
||||||
File target = new File(file);
|
File target = new File(file);
|
||||||
if (!target.exists())
|
if (!target.exists())
|
||||||
|
|
Loading…
Reference in New Issue