diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java index 14401d691..16c78923f 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java @@ -16,6 +16,9 @@ package org.nd4j.config; +import java.io.File; +import java.net.URL; + public class ND4JSystemProperties { /** @@ -125,6 +128,22 @@ public class ND4JSystemProperties { */ public static final String RESOURCES_CACHE_DIR = "org.nd4j.test.resources.cache.dir"; + /** + * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
+ * Description: When resolving resources from a Strumpf resource file (Example: {@code Resources.asFile("myFile.txt")} + * what should be the connection timeout, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)}
+ * Default: {@link org.nd4j.resources.strumpf.ResourceFile#DEFAULT_CONNECTION_TIMEOUT} + */ + public static final String RESOURCES_CONNECTION_TIMEOUT = "org.nd4j.resources.download.connectiontimeout"; + + /** + * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
+ * Description: When resolving resources from a Strumpf resource file (Example: {@code Resources.asFile("myFile.txt")} + * what should be the connection timeout, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)}
+ * Default: {@link org.nd4j.resources.strumpf.ResourceFile#DEFAULT_READ_TIMEOUT} + */ + public static final String RESOURCES_READ_TIMEOUT = "org.nd4j.resources.download.readtimeout"; + /** * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
* Description: When resolving resources, what local directories should be checked (in addition to the classpath) for files? diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java index 05c44c29e..19352fc7c 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java @@ -34,30 +34,49 @@ import java.net.URL; */ @Slf4j public class Downloader { + /** + * Default connection timeout in milliseconds when using {@link FileUtils#copyURLToFile(URL, File, int, int)} + */ + public static final int DEFAULT_CONNECTION_TIMEOUT = 60000; + /** + * Default read timeout in milliseconds when using {@link FileUtils#copyURLToFile(URL, File, int, int)} + */ + public static final int DEFAULT_READ_TIMEOUT = 60000; private Downloader(){ } /** - * Download the specified URL to the specified file, and verify that the target MD5 matches - * @param name Name (mainly for providing useful exceptions) - * @param url URL to download - * @param f Destination file - * @param targetMD5 Expected MD5 for file - * @param maxTries Maximum number of download attempts before failing and throwing an exception - * @throws IOException If an error occurs during downloading + * As per {@link #download(String, URL, File, String, int, int, int)} with the connection and read timeouts + * set to their default values - {@link #DEFAULT_CONNECTION_TIMEOUT} and {@link #DEFAULT_READ_TIMEOUT} respectively */ public static void download(String name, URL url, File f, String targetMD5, int maxTries) throws IOException { - download(name, url, f, targetMD5, maxTries, 0); + download(name, url, f, targetMD5, maxTries, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_READ_TIMEOUT); } - private static void download(String name, URL url, File f, String targetMD5, int maxTries, int attempt) throws IOException { + /** + * Download the specified URL to the specified file, and verify that the target MD5 matches + * + * @param name Name (mainly for providing useful exceptions) + * @param url URL to download + * @param f Destination file + * @param targetMD5 Expected MD5 for file + * @param maxTries Maximum number of download attempts before failing and throwing an exception + * @param connectionTimeout connection timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @param readTimeout read timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @throws IOException If an error occurs during downloading + */ + public static void download(String name, URL url, File f, String targetMD5, int maxTries, int connectionTimeout, int readTimeout) throws IOException { + download(name, url, f, targetMD5, maxTries, 0, connectionTimeout, readTimeout); + } + + private static void download(String name, URL url, File f, String targetMD5, int maxTries, int attempt, int connectionTimeout, int readTimeout) throws IOException { boolean isCorrectFile = f.exists() && f.isFile() && checkMD5OfFile(targetMD5, f); if (attempt < maxTries) { if(!isCorrectFile) { - FileUtils.copyURLToFile(url, f); + FileUtils.copyURLToFile(url, f, connectionTimeout, readTimeout); if (!checkMD5OfFile(targetMD5, f)) { f.delete(); - download(name, url, f, targetMD5, maxTries, attempt + 1); + download(name, url, f, targetMD5, maxTries, attempt + 1, connectionTimeout, readTimeout); } } } else if (!isCorrectFile) { @@ -67,6 +86,14 @@ public class Downloader { } } + /** + * As per {@link #downloadAndExtract(String, URL, File, File, String, int, int, int)} with the connection and read timeouts + * * set to their default values - {@link #DEFAULT_CONNECTION_TIMEOUT} and {@link #DEFAULT_READ_TIMEOUT} respectively + */ + public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries) throws IOException { + downloadAndExtract(name, url, f, extractToDir, targetMD5, maxTries, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_READ_TIMEOUT); + } + /** * Download the specified URL to the specified file, verify that the MD5 matches, and then extract it to the specified directory.
* Note that the file must be an archive, with the correct file extension: .zip, .jar, .tar.gz, .tgz or .gz @@ -77,20 +104,24 @@ public class Downloader { * @param extractToDir Destination directory to extract all files * @param targetMD5 Expected MD5 for file * @param maxTries Maximum number of download attempts before failing and throwing an exception + * @param connectionTimeout connection timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @param readTimeout read timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} * @throws IOException If an error occurs during downloading */ - public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries) throws IOException { - downloadAndExtract(0, maxTries, name, url, f, extractToDir, targetMD5); + public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries, + int connectionTimeout, int readTimeout) throws IOException { + downloadAndExtract(0, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } - private static void downloadAndExtract(int attempt, int maxTries, String name, URL url, File f, File extractToDir, String targetMD5) throws IOException { + private static void downloadAndExtract(int attempt, int maxTries, String name, URL url, File f, File extractToDir, + String targetMD5, int connectionTimeout, int readTimeout) throws IOException { boolean isCorrectFile = f.exists() && f.isFile() && checkMD5OfFile(targetMD5, f); if (attempt < maxTries) { if(!isCorrectFile) { - FileUtils.copyURLToFile(url, f); + FileUtils.copyURLToFile(url, f, connectionTimeout, readTimeout); if (!checkMD5OfFile(targetMD5, f)) { f.delete(); - downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5); + downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } } // try extracting @@ -99,7 +130,7 @@ public class Downloader { } catch (Throwable t){ log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t); f.delete(); - downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5); + downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } } else if (!isCorrectFile) { //Too many attempts diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java index f8fca14b5..6a69bd3b9 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java @@ -1,5 +1,6 @@ package org.nd4j.resources.strumpf; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.shade.guava.io.Files; import lombok.AllArgsConstructor; import lombok.Data; @@ -32,6 +33,14 @@ import java.util.Map; @JsonIgnoreProperties("filePath") @Slf4j public class ResourceFile { + /** + * Default value for resource downloading connection timeout - see {@link ND4JSystemProperties#RESOURCES_CONNECTION_TIMEOUT} + */ + public static final int DEFAULT_CONNECTION_TIMEOUT = 60000; //Timeout for connections to be established + /** + * Default value for resource downloading read timeout - see {@link ND4JSystemProperties#RESOURCES_READ_TIMEOUT} + */ + public static final int DEFAULT_READ_TIMEOUT = 60000; //Timeout for amount of time between connection established and data is available protected static final String PATH_KEY = "full_remote_path"; protected static final String HASH = "_hash"; protected static final String COMPRESSED_HASH = "_compressed_hash"; @@ -146,15 +155,20 @@ public class ResourceFile { String sha256PropertyCompressed = relativePath() + COMPRESSED_HASH; - //TODO NEXT LINE IN TEMPORARY UNTIL FIXED IN STRUMPF 0.3.2 -// sha256PropertyCompressed = sha256PropertyCompressed.replaceAll("/", "\\\\"); - String sha256Compressed = v1.get(sha256PropertyCompressed); Preconditions.checkState(sha256Compressed != null, "Expected JSON property %s was not found in resource reference file %s", sha256PropertyCompressed, filePath); String sha256Property = relativePath() + HASH; String sha256Uncompressed = v1.get(sha256Property); + String connTimeoutStr = System.getProperty(ND4JSystemProperties.RESOURCES_CONNECTION_TIMEOUT); + String readTimeoutStr = System.getProperty(ND4JSystemProperties.RESOURCES_READ_TIMEOUT); + boolean validCTimeout = connTimeoutStr != null && connTimeoutStr.matches("\\d+"); + boolean validRTimeout = readTimeoutStr != null && readTimeoutStr.matches("\\d+"); + + int connectTimeout = validCTimeout ? Integer.parseInt(connTimeoutStr) : DEFAULT_CONNECTION_TIMEOUT; + int readTimeout = validRTimeout ? Integer.parseInt(readTimeoutStr) : DEFAULT_READ_TIMEOUT; + try { boolean correctHash = false; for (int tryCount = 0; tryCount < MAX_DOWNLOAD_ATTEMPTS; tryCount++) { @@ -162,7 +176,7 @@ public class ResourceFile { if (tempFile.exists()) tempFile.delete(); log.info("Downloading remote resource {} to {}", remotePath, tempFile); - FileUtils.copyURLToFile(new URL(remotePath), tempFile); + FileUtils.copyURLToFile(new URL(remotePath), tempFile, connectTimeout, readTimeout); //Now: check if downloaded archive hash is OK String hash = sha256(tempFile); correctHash = sha256Compressed.equals(hash);