diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index 5c48b4c18..ad8601000 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; + return 120_000L; } @Test @@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), new MaxCandidatesCondition(3)) .build(); diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java index 91daa027f..caeffaaa7 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java @@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 45000L; + return 120_000L; } @Test @@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator())); diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java index 52a2eb9b9..6ec7b0935 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java @@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable { String fileName = file.toString(); if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz") || fileName.endsWith(".zip")) - ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath()); + ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false); } catch (IOException e) { throw new IllegalStateException("Unable to fetch images", e); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index c7fed81ed..5b360b349 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -66,10 +66,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper { if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports floating point dtype - //Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW] - INDArray weightsPermute = weights.permute(2,3,1,0); - INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0); - int hDim = 2; int wDim = 3; if(format == CNN2DFormat.NHWC){ @@ -89,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper { pad[0], pad[1], dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), - format == CNN2DFormat.NCHW ? 0 : 1 //0=NCHW, 1=NHWC + format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC + 1 //Weight format: 1 - [oC, iC, kH, kW] ); }; INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); - INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta}; - INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView}; + INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta}; + INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView}; contextBwd.purge(); for( int i=0; i + * Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if + * logging is not desired.
+ * Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats. + * Format is interpreted from the filename * - * @param file the file to extract to - * @param dest the destination directory - * @throws IOException + * @param file the file to extract the files from + * @param dest the destination directory. Will be created if it does not exist + * @throws IOException If an error occurs accessing the files or extracting */ public static void unzipFileTo(String file, String dest) throws IOException { unzipFileTo(file, dest, true); } + /** + * Extracts all files from the archive to the specified destination, optionally logging the extracted file path.
+ * Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats. + * Format is interpreted from the filename + * + * @param file the file to extract the files from + * @param dest the destination directory. Will be created if it does not exist + * @param logFiles If true: log the path of every extracted file; if false do not log + * @throws IOException If an error occurs accessing the files or extracting + */ public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException { File target = new File(file); if (!target.exists())