From 1df8899fb1ff6a4d55f8ce41889ea6d6d3e6b757 Mon Sep 17 00:00:00 2001 From: cspriegel Date: Wed, 26 Feb 2020 16:24:32 +0100 Subject: [PATCH] restoreMultiLayerNetwork() needlessly writes temp-file #8735 Signed-off-by: cspriegel --- .../deeplearning4j/util/ModelSerializer.java | 355 ++++++++---------- 1 file changed, 156 insertions(+), 199 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index 70945b06c..aa0cab9a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -43,9 +43,12 @@ import org.nd4j.linalg.primitives.Pair; import java.io.*; import java.util.ArrayList; import java.util.Enumeration; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; /** @@ -215,7 +218,24 @@ public class ModelSerializer { */ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException { - ZipFile zipFile = new ZipFile(file); + return restoreMultiLayerNetwork(new FileInputStream(file), loadUpdater); + } + + + /** + * Load a MultiLayerNetwork from InputStream from an input stream
+ * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. + * + * @param is the inputstream to load from + * @return the loaded multi layer network + * @throws IOException + * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean) + */ + public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) + throws IOException { + checkInputStream(is); + + Map zipFile = loadZipData(is); boolean gotConfig = false; boolean gotCoefficients = false; @@ -229,11 +249,11 @@ public class ModelSerializer { DataSetPreProcessor preProcessor = null; - ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON); + byte[] config = zipFile.get(CONFIGURATION_JSON); if (config != null) { //restoring configuration - InputStream stream = zipFile.getInputStream(config); + InputStream stream = new ByteArrayInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); @@ -248,25 +268,25 @@ public class ModelSerializer { } - ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN); + byte[] coefficients = zipFile.get(COEFFICIENTS_BIN); if (coefficients != null ) { - if(coefficients.getSize() > 0) { - InputStream stream = zipFile.getInputStream(coefficients); + if(coefficients.length > 0) { + InputStream stream = new ByteArrayInputStream(coefficients); DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } else { - ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER); + byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER); gotCoefficients = (noParamsMarker != null); } } if (loadUpdater) { - ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN); + byte[] updaterStateEntry = zipFile.get(UPDATER_BIN); if (updaterStateEntry != null) { - InputStream stream = zipFile.getInputStream(updaterStateEntry); + InputStream stream = new ByteArrayInputStream(updaterStateEntry); DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); updaterState = Nd4j.read(dis); @@ -275,9 +295,9 @@ public class ModelSerializer { } } - ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN); + byte[] prep = zipFile.get(PREPROCESSOR_BIN); if (prep != null) { - InputStream stream = zipFile.getInputStream(prep); + InputStream stream = new ByteArrayInputStream(prep); ObjectInputStream ois = new ObjectInputStream(stream); try { @@ -290,7 +310,6 @@ public class ModelSerializer { } - zipFile.close(); if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson; @@ -328,31 +347,6 @@ public class ModelSerializer { + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); } - - /** - * Load a MultiLayerNetwork from InputStream from an input stream
- * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. - * - * @param is the inputstream to load from - * @return the loaded multi layer network - * @throws IOException - * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean) - */ - public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) - throws IOException { - checkInputStream(is); - - File tmpFile = null; - try{ - tmpFile = tempFileFromStream(is); - return restoreMultiLayerNetwork(tmpFile, loadUpdater); - } finally { - if(tmpFile != null){ - tmpFile.delete(); - } - } - } - /** * Restore a multi layer network from an input stream
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. @@ -404,15 +398,9 @@ public class ModelSerializer { @NonNull InputStream is, boolean loadUpdater) throws IOException { checkInputStream(is); - File tmpFile = null; - try { - tmpFile = tempFileFromStream(is); - return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater); - } finally { - if (tmpFile != null) { - tmpFile.delete(); - } - } + MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater); + Normalizer norm = restoreNormalizerFromInputStream(is); + return new Pair<>(net, norm); } /** @@ -425,9 +413,7 @@ public class ModelSerializer { */ public static Pair restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater) throws IOException { - MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater); - Normalizer norm = restoreNormalizerFromFile(file); - return new Pair<>(net, norm); + return restoreMultiLayerNetworkAndNormalizer(new FileInputStream(file), loadUpdater); } /** @@ -465,87 +451,7 @@ public class ModelSerializer { throws IOException { checkInputStream(is); - File tmpFile = null; - try{ - tmpFile = tempFileFromStream(is); - return restoreComputationGraph(tmpFile, loadUpdater); - } finally { - if(tmpFile != null){ - tmpFile.delete(); - } - } - } - - /** - * Load a computation graph from a InputStream - * @param is the inputstream to get the computation graph from - * @return the loaded computation graph - * - * @throws IOException - */ - public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException { - return restoreComputationGraph(is, true); - } - - /** - * Load a computation graph from a file - * @param file the file to get the computation graph from - * @return the loaded computation graph - * - * @throws IOException - */ - public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException { - return restoreComputationGraph(file, true); - } - - /** - * Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream. - * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. - * - * @param is Input stream to read from - * @param loadUpdater Whether to load the updater from the model or not - * @return Model and normalizer, if present - * @throws IOException If an error occurs when reading from the stream - */ - public static Pair restoreComputationGraphAndNormalizer( - @NonNull InputStream is, boolean loadUpdater) throws IOException { - checkInputStream(is); - - File tmpFile = null; - try { - tmpFile = tempFileFromStream(is); - return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater); - } finally { - if (tmpFile != null) { - tmpFile.delete(); - } - } - } - - /** - * Restore a ComputationGraph and Normalizer (if present - null if not) from a File - * - * @param file File to read the model and normalizer from - * @param loadUpdater Whether to load the updater from the model or not - * @return Model and normalizer, if present - * @throws IOException If an error occurs when reading from the File - */ - public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater) - throws IOException { - ComputationGraph net = restoreComputationGraph(file, loadUpdater); - Normalizer norm = restoreNormalizerFromFile(file); - return new Pair<>(net, norm); - } - - /** - * Load a computation graph from a file - * @param file the file to get the computation graph from - * @return the loaded computation graph - * - * @throws IOException - */ - public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException { - ZipFile zipFile = new ZipFile(file); + Map files = loadZipData(is); boolean gotConfig = false; boolean gotCoefficients = false; @@ -558,11 +464,11 @@ public class ModelSerializer { DataSetPreProcessor preProcessor = null; - ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON); + byte[] config = files.get(CONFIGURATION_JSON); if (config != null) { //restoring configuration - InputStream stream = zipFile.getInputStream(config); + InputStream stream = new ByteArrayInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); @@ -577,27 +483,27 @@ public class ModelSerializer { } - ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN); + byte[] coefficients = files.get(COEFFICIENTS_BIN); if (coefficients != null) { - if(coefficients.getSize() > 0) { - InputStream stream = zipFile.getInputStream(coefficients); - DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); + if(coefficients.length > 0) { + InputStream stream = new ByteArrayInputStream(coefficients); + DataInputStream dis = new DataInputStream(stream); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } else { - ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER); + byte[] noParamsMarker = files.get(NO_PARAMS_MARKER); gotCoefficients = (noParamsMarker != null); } } if (loadUpdater) { - ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN); + byte[] updaterStateEntry = files.get(UPDATER_BIN); if (updaterStateEntry != null) { - InputStream stream = zipFile.getInputStream(updaterStateEntry); - DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); + InputStream stream = new ByteArrayInputStream(updaterStateEntry); + DataInputStream dis = new DataInputStream(stream); updaterState = Nd4j.read(dis); dis.close(); @@ -605,9 +511,9 @@ public class ModelSerializer { } } - ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN); + byte[] prep = files.get(PREPROCESSOR_BIN); if (prep != null) { - InputStream stream = zipFile.getInputStream(prep); + InputStream stream = new ByteArrayInputStream(prep); ObjectInputStream ois = new ObjectInputStream(stream); try { @@ -620,8 +526,6 @@ public class ModelSerializer { } - zipFile.close(); - if (gotConfig && gotCoefficients) { ComputationGraphConfiguration confFromJson; try{ @@ -662,6 +566,70 @@ public class ModelSerializer { + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); } + /** + * Load a computation graph from a InputStream + * @param is the inputstream to get the computation graph from + * @return the loaded computation graph + * + * @throws IOException + */ + public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException { + return restoreComputationGraph(is, true); + } + + /** + * Load a computation graph from a file + * @param file the file to get the computation graph from + * @return the loaded computation graph + * + * @throws IOException + */ + public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException { + return restoreComputationGraph(file, true); + } + + /** + * Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream. + * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. + * + * @param is Input stream to read from + * @param loadUpdater Whether to load the updater from the model or not + * @return Model and normalizer, if present + * @throws IOException If an error occurs when reading from the stream + */ + public static Pair restoreComputationGraphAndNormalizer( + @NonNull InputStream is, boolean loadUpdater) throws IOException { + checkInputStream(is); + + ComputationGraph net = restoreComputationGraph(is, loadUpdater); + Normalizer norm = restoreNormalizerFromInputStream(is); + return new Pair<>(net, norm); + } + + /** + * Restore a ComputationGraph and Normalizer (if present - null if not) from a File + * + * @param file File to read the model and normalizer from + * @param loadUpdater Whether to load the updater from the model or not + * @return Model and normalizer, if present + * @throws IOException If an error occurs when reading from the File + */ + public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater) + throws IOException { + return restoreComputationGraphAndNormalizer(new FileInputStream(file), loadUpdater); + } + + /** + * Load a computation graph from a file + * @param file the file to get the computation graph from + * @return the loaded computation graph + * + * @throws IOException + */ + public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException { + return restoreComputationGraph(new FileInputStream(file), loadUpdater); + } + /** * * @param model @@ -811,15 +779,16 @@ public class ModelSerializer { } //Add new object: - ZipEntry entry = new ZipEntry("objects/" + key); - writeFile.putNextEntry(entry); try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ oos.writeObject(o); byte[] bytes = baos.toByteArray(); + ZipEntry entry = new ZipEntry("objects/" + key); + entry.setSize(bytes.length); + writeFile.putNextEntry(entry); writeFile.write(bytes); + writeFile.closeEntry(); } - writeFile.closeEntry(); writeFile.close(); zipFile.close(); @@ -904,18 +873,12 @@ public class ModelSerializer { * @param file * @return */ - public static T restoreNormalizerFromFile(File file) { - try (ZipFile zipFile = new ZipFile(file)) { - ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN); - - // checking for file existence - if (norm == null) - return null; - - return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm)); + public static T restoreNormalizerFromFile(File file) throws IOException { + try { + return restoreNormalizerFromInputStream(new FileInputStream(file)); } catch (Exception e) { log.warn("Error while restoring normalizer, trying to restore assuming deprecated format..."); - DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file); + DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file)); log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue."); addNormalizerToModel(file, restoredDeprecated); @@ -934,15 +897,18 @@ public class ModelSerializer { public static T restoreNormalizerFromInputStream(InputStream is) throws IOException { checkInputStream(is); - File tmpFile = null; + Map files = loadZipData(is); + byte[] norm = files.get(NORMALIZER_BIN); + + // checking for file existence + if (norm == null) + return null; try { - tmpFile = tempFileFromStream(is); - return restoreNormalizerFromFile(tmpFile); - } finally { - if(tmpFile != null){ - tmpFile.delete(); - } + return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm)); } + catch (Exception e) { + throw new IOException("Error loading normalizer", e); + } } /** @@ -953,17 +919,9 @@ public class ModelSerializer { * @param file * @return */ - private static DataNormalization restoreNormalizerFromFileDeprecated(File file) { - try (ZipFile zipFile = new ZipFile(file)) { - ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN); - - // checking for file existence - if (norm == null) - return null; - - InputStream stream = zipFile.getInputStream(norm); + private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) { + try { ObjectInputStream ois = new ObjectInputStream(stream); - try { DataNormalization normalizer = (DataNormalization) ois.readObject(); return normalizer; @@ -996,31 +954,30 @@ public class ModelSerializer { */ } - private static void checkTempFileFromInputStream(File f) throws IOException { - if (f.length() <= 0) { - throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." + - " Stream may have been closed before reading, is attempting to be used multiple times, or does not" + - " point to a model file?"); - } + private static Map loadZipData(InputStream is) throws IOException { + Map result = new HashMap<>(); + try (final ZipInputStream zis = new ZipInputStream(is)) { + while (true) { + final ZipEntry zipEntry = zis.getNextEntry(); + if (zipEntry == null) + break; + if(zipEntry.isDirectory() || zipEntry.getSize() > Integer.MAX_VALUE) + throw new IllegalArgumentException(); + + final int size = (int) (zipEntry.getSize()); + final byte[] data; + if (size >= 0) { // known size + data = IOUtils.readFully(zis, size); + } + else { // unknown size + final ByteArrayOutputStream bout = new ByteArrayOutputStream(); + IOUtils.copy(zis, bout); + data = bout.toByteArray(); + } + result.put(zipEntry.getName(), data); + } + } + return result; } - private static File tempFileFromStream(InputStream is) throws IOException{ - checkInputStream(is); - String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY); - File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin"); - try { - tmpFile.deleteOnExit(); - BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile)); - IOUtils.copy(is, bufferedOutputStream); - bufferedOutputStream.flush(); - IOUtils.closeQuietly(bufferedOutputStream); - checkTempFileFromInputStream(tmpFile); - return tmpFile; - } catch (IOException e){ - if(tmpFile != null){ - tmpFile.delete(); - } - throw e; - } - } }