diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java index 5e75bd3ce..96e29d1ac 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java @@ -54,7 +54,11 @@ public class ModelGuesser { * @return the loaded normalizer */ public static Normalizer loadNormalizer(String path) { - return ModelSerializer.restoreNormalizerFromFile(new File(path)); + try { + return ModelSerializer.restoreNormalizerFromFile(new File(path)); + } catch (IOException e){ + throw new RuntimeException(e); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index 70945b06c..30dbe7b4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -16,6 +16,7 @@ package org.deeplearning4j.util; +import org.apache.commons.io.input.CloseShieldInputStream; import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; @@ -43,9 +44,12 @@ import org.nd4j.linalg.primitives.Pair; import java.io.*; import java.util.ArrayList; import java.util.Enumeration; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; /** @@ -215,7 +219,31 @@ public class ModelSerializer { */ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException { - ZipFile zipFile = new ZipFile(file); + try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ + return restoreMultiLayerNetwork(is, loadUpdater); + } + } + + + /** + * Load a MultiLayerNetwork from InputStream from an input stream
+ * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. + * + * @param is the inputstream to load from + * @return the loaded multi layer network + * @throws IOException + * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean) + */ + public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) + throws IOException { + return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst(); + } + + private static Pair> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater) + throws IOException { + checkInputStream(is); + + Map zipFile = loadZipData(is); boolean gotConfig = false; boolean gotCoefficients = false; @@ -229,11 +257,11 @@ public class ModelSerializer { DataSetPreProcessor preProcessor = null; - ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON); + byte[] config = zipFile.get(CONFIGURATION_JSON); if (config != null) { //restoring configuration - InputStream stream = zipFile.getInputStream(config); + InputStream stream = new ByteArrayInputStream(config); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; StringBuilder js = new StringBuilder(); @@ -248,25 +276,25 @@ public class ModelSerializer { } - ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN); + byte[] coefficients = zipFile.get(COEFFICIENTS_BIN); if (coefficients != null ) { - if(coefficients.getSize() > 0) { - InputStream stream = zipFile.getInputStream(coefficients); + if(coefficients.length > 0) { + InputStream stream = new ByteArrayInputStream(coefficients); DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); params = Nd4j.read(dis); dis.close(); gotCoefficients = true; } else { - ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER); + byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER); gotCoefficients = (noParamsMarker != null); } } if (loadUpdater) { - ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN); + byte[] updaterStateEntry = zipFile.get(UPDATER_BIN); if (updaterStateEntry != null) { - InputStream stream = zipFile.getInputStream(updaterStateEntry); + InputStream stream = new ByteArrayInputStream(updaterStateEntry); DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); updaterState = Nd4j.read(dis); @@ -275,9 +303,9 @@ public class ModelSerializer { } } - ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN); + byte[] prep = zipFile.get(PREPROCESSOR_BIN); if (prep != null) { - InputStream stream = zipFile.getInputStream(prep); + InputStream stream = new ByteArrayInputStream(prep); ObjectInputStream ois = new ObjectInputStream(stream); try { @@ -290,7 +318,6 @@ public class ModelSerializer { } - zipFile.close(); if (gotConfig && gotCoefficients) { MultiLayerConfiguration confFromJson; @@ -322,37 +349,12 @@ public class ModelSerializer { if (gotUpdaterState && updaterState != null) { network.getUpdater().setStateViewArray(network, updaterState, false); } - return network; + return new Pair<>(network, zipFile); } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); } - - /** - * Load a MultiLayerNetwork from InputStream from an input stream
- * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. - * - * @param is the inputstream to load from - * @return the loaded multi layer network - * @throws IOException - * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean) - */ - public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) - throws IOException { - checkInputStream(is); - - File tmpFile = null; - try{ - tmpFile = tempFileFromStream(is); - return restoreMultiLayerNetwork(tmpFile, loadUpdater); - } finally { - if(tmpFile != null){ - tmpFile.delete(); - } - } - } - /** * Restore a multi layer network from an input stream
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used. @@ -403,16 +405,12 @@ public class ModelSerializer { public static Pair restoreMultiLayerNetworkAndNormalizer( @NonNull InputStream is, boolean loadUpdater) throws IOException { checkInputStream(is); + is = new CloseShieldInputStream(is); - File tmpFile = null; - try { - tmpFile = tempFileFromStream(is); - return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater); - } finally { - if (tmpFile != null) { - tmpFile.delete(); - } - } + Pair> p = restoreMultiLayerNetworkHelper(is, loadUpdater); + MultiLayerNetwork net = p.getFirst(); + Normalizer norm = restoreNormalizerFromMap(p.getSecond()); + return new Pair<>(net, norm); } /** @@ -425,9 +423,9 @@ public class ModelSerializer { */ public static Pair restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater) throws IOException { - MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater); - Normalizer norm = restoreNormalizerFromFile(file); - return new Pair<>(net, norm); + try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ + return restoreMultiLayerNetworkAndNormalizer(is, loadUpdater); + } } /** @@ -463,17 +461,126 @@ public class ModelSerializer { */ public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater) throws IOException { + return restoreComputationGraphHelper(is, loadUpdater).getFirst(); + } + + private static Pair> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater) + throws IOException { checkInputStream(is); - File tmpFile = null; - try{ - tmpFile = tempFileFromStream(is); - return restoreComputationGraph(tmpFile, loadUpdater); - } finally { - if(tmpFile != null){ - tmpFile.delete(); + Map files = loadZipData(is); + + boolean gotConfig = false; + boolean gotCoefficients = false; + boolean gotUpdaterState = false; + boolean gotPreProcessor = false; + + String json = ""; + INDArray params = null; + INDArray updaterState = null; + DataSetPreProcessor preProcessor = null; + + + byte[] config = files.get(CONFIGURATION_JSON); + if (config != null) { + //restoring configuration + + InputStream stream = new ByteArrayInputStream(config); + BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); + String line = ""; + StringBuilder js = new StringBuilder(); + while ((line = reader.readLine()) != null) { + js.append(line).append("\n"); + } + json = js.toString(); + + reader.close(); + stream.close(); + gotConfig = true; + } + + + byte[] coefficients = files.get(COEFFICIENTS_BIN); + if (coefficients != null) { + if(coefficients.length > 0) { + InputStream stream = new ByteArrayInputStream(coefficients); + DataInputStream dis = new DataInputStream(stream); + params = Nd4j.read(dis); + + dis.close(); + gotCoefficients = true; + } else { + byte[] noParamsMarker = files.get(NO_PARAMS_MARKER); + gotCoefficients = (noParamsMarker != null); } } + + + if (loadUpdater) { + byte[] updaterStateEntry = files.get(UPDATER_BIN); + if (updaterStateEntry != null) { + InputStream stream = new ByteArrayInputStream(updaterStateEntry); + DataInputStream dis = new DataInputStream(stream); + updaterState = Nd4j.read(dis); + + dis.close(); + gotUpdaterState = true; + } + } + + byte[] prep = files.get(PREPROCESSOR_BIN); + if (prep != null) { + InputStream stream = new ByteArrayInputStream(prep); + ObjectInputStream ois = new ObjectInputStream(stream); + + try { + preProcessor = (DataSetPreProcessor) ois.readObject(); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + + gotPreProcessor = true; + } + + + if (gotConfig && gotCoefficients) { + ComputationGraphConfiguration confFromJson; + try{ + confFromJson = ComputationGraphConfiguration.fromJson(json); + if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){ + //May be deserialized correctly, but mostly with null fields + throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration"); + } + } catch (Exception e){ + if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){ + throw e; + } + try{ + MultiLayerConfiguration.fromJson(json); + } catch (Exception e2){ + //Invalid, and not a compgraph + throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" + + " not a valid ComputationGraphConfiguration", e); + } + throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " + + "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead"); + } + + //Handle legacy config - no network DataType in config, in beta3 or earlier + if(params != null) + confFromJson.setDataType(params.dataType()); + + ComputationGraph cg = new ComputationGraph(confFromJson); + cg.init(params, false); + + + if (gotUpdaterState && updaterState != null) { + cg.getUpdater().setStateViewArray(updaterState); + } + return new Pair<>(cg, files); + } else + throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); } /** @@ -511,15 +618,11 @@ public class ModelSerializer { @NonNull InputStream is, boolean loadUpdater) throws IOException { checkInputStream(is); - File tmpFile = null; - try { - tmpFile = tempFileFromStream(is); - return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater); - } finally { - if (tmpFile != null) { - tmpFile.delete(); - } - } + + Pair> p = restoreComputationGraphHelper(is, loadUpdater); + ComputationGraph net = p.getFirst(); + Normalizer norm = restoreNormalizerFromMap(p.getSecond()); + return new Pair<>(net, norm); } /** @@ -532,9 +635,7 @@ public class ModelSerializer { */ public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater) throws IOException { - ComputationGraph net = restoreComputationGraph(file, loadUpdater); - Normalizer norm = restoreNormalizerFromFile(file); - return new Pair<>(net, norm); + return restoreComputationGraphAndNormalizer(new FileInputStream(file), loadUpdater); } /** @@ -545,121 +646,7 @@ public class ModelSerializer { * @throws IOException */ public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException { - ZipFile zipFile = new ZipFile(file); - - boolean gotConfig = false; - boolean gotCoefficients = false; - boolean gotUpdaterState = false; - boolean gotPreProcessor = false; - - String json = ""; - INDArray params = null; - INDArray updaterState = null; - DataSetPreProcessor preProcessor = null; - - - ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON); - if (config != null) { - //restoring configuration - - InputStream stream = zipFile.getInputStream(config); - BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); - String line = ""; - StringBuilder js = new StringBuilder(); - while ((line = reader.readLine()) != null) { - js.append(line).append("\n"); - } - json = js.toString(); - - reader.close(); - stream.close(); - gotConfig = true; - } - - - ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN); - if (coefficients != null) { - if(coefficients.getSize() > 0) { - InputStream stream = zipFile.getInputStream(coefficients); - DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); - params = Nd4j.read(dis); - - dis.close(); - gotCoefficients = true; - } else { - ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER); - gotCoefficients = (noParamsMarker != null); - } - } - - - if (loadUpdater) { - ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN); - if (updaterStateEntry != null) { - InputStream stream = zipFile.getInputStream(updaterStateEntry); - DataInputStream dis = new DataInputStream(new BufferedInputStream(stream)); - updaterState = Nd4j.read(dis); - - dis.close(); - gotUpdaterState = true; - } - } - - ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN); - if (prep != null) { - InputStream stream = zipFile.getInputStream(prep); - ObjectInputStream ois = new ObjectInputStream(stream); - - try { - preProcessor = (DataSetPreProcessor) ois.readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - - gotPreProcessor = true; - } - - - zipFile.close(); - - if (gotConfig && gotCoefficients) { - ComputationGraphConfiguration confFromJson; - try{ - confFromJson = ComputationGraphConfiguration.fromJson(json); - if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){ - //May be deserialized correctly, but mostly with null fields - throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration"); - } - } catch (Exception e){ - if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){ - throw e; - } - try{ - MultiLayerConfiguration.fromJson(json); - } catch (Exception e2){ - //Invalid, and not a compgraph - throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" + - " not a valid ComputationGraphConfiguration", e); - } - throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " + - "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead"); - } - - //Handle legacy config - no network DataType in config, in beta3 or earlier - if(params != null) - confFromJson.setDataType(params.dataType()); - - ComputationGraph cg = new ComputationGraph(confFromJson); - cg.init(params, false); - - - if (gotUpdaterState && updaterState != null) { - cg.getUpdater().setStateViewArray(updaterState); - } - return cg; - } else - throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig - + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); + return restoreComputationGraph(new FileInputStream(file), loadUpdater); } /** @@ -811,15 +798,16 @@ public class ModelSerializer { } //Add new object: - ZipEntry entry = new ZipEntry("objects/" + key); - writeFile.putNextEntry(entry); try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ oos.writeObject(o); byte[] bytes = baos.toByteArray(); + ZipEntry entry = new ZipEntry("objects/" + key); + entry.setSize(bytes.length); + writeFile.putNextEntry(entry); writeFile.write(bytes); + writeFile.closeEntry(); } - writeFile.closeEntry(); writeFile.close(); zipFile.close(); @@ -904,18 +892,12 @@ public class ModelSerializer { * @param file * @return */ - public static T restoreNormalizerFromFile(File file) { - try (ZipFile zipFile = new ZipFile(file)) { - ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN); - - // checking for file existence - if (norm == null) - return null; - - return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm)); + public static T restoreNormalizerFromFile(File file) throws IOException { + try (InputStream is = new BufferedInputStream(new FileInputStream(file))) { + return restoreNormalizerFromInputStream(is); } catch (Exception e) { log.warn("Error while restoring normalizer, trying to restore assuming deprecated format..."); - DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file); + DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file)); log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue."); addNormalizerToModel(file, restoredDeprecated); @@ -933,16 +915,22 @@ public class ModelSerializer { */ public static T restoreNormalizerFromInputStream(InputStream is) throws IOException { checkInputStream(is); + Map files = loadZipData(is); + return restoreNormalizerFromMap(files); + } - File tmpFile = null; + private static T restoreNormalizerFromMap(Map files) throws IOException { + byte[] norm = files.get(NORMALIZER_BIN); + + // checking for file existence + if (norm == null) + return null; try { - tmpFile = tempFileFromStream(is); - return restoreNormalizerFromFile(tmpFile); - } finally { - if(tmpFile != null){ - tmpFile.delete(); - } + return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm)); } + catch (Exception e) { + throw new IOException("Error loading normalizer", e); + } } /** @@ -950,20 +938,10 @@ public class ModelSerializer { * * This method restores normalizer from a given persisted model file serialized with Java object serialization * - * @param file - * @return */ - private static DataNormalization restoreNormalizerFromFileDeprecated(File file) { - try (ZipFile zipFile = new ZipFile(file)) { - ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN); - - // checking for file existence - if (norm == null) - return null; - - InputStream stream = zipFile.getInputStream(norm); + private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) { + try { ObjectInputStream ois = new ObjectInputStream(stream); - try { DataNormalization normalizer = (DataNormalization) ois.readObject(); return normalizer; @@ -977,8 +955,6 @@ public class ModelSerializer { private static void checkInputStream(InputStream inputStream) throws IOException { - - /* //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887 int available; try{ @@ -993,34 +969,32 @@ public class ModelSerializer { throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" + "multiple times?"); } - */ } - private static void checkTempFileFromInputStream(File f) throws IOException { - if (f.length() <= 0) { - throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." + - " Stream may have been closed before reading, is attempting to be used multiple times, or does not" + - " point to a model file?"); - } + private static Map loadZipData(InputStream is) throws IOException { + Map result = new HashMap<>(); + try (final ZipInputStream zis = new ZipInputStream(is)) { + while (true) { + final ZipEntry zipEntry = zis.getNextEntry(); + if (zipEntry == null) + break; + if(zipEntry.isDirectory() || zipEntry.getSize() > Integer.MAX_VALUE) + throw new IllegalArgumentException(); + + final int size = (int) (zipEntry.getSize()); + final byte[] data; + if (size >= 0) { // known size + data = IOUtils.readFully(zis, size); + } + else { // unknown size + final ByteArrayOutputStream bout = new ByteArrayOutputStream(); + IOUtils.copy(zis, bout); + data = bout.toByteArray(); + } + result.put(zipEntry.getName(), data); + } + } + return result; } - private static File tempFileFromStream(InputStream is) throws IOException{ - checkInputStream(is); - String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY); - File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin"); - try { - tmpFile.deleteOnExit(); - BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile)); - IOUtils.copy(is, bufferedOutputStream); - bufferedOutputStream.flush(); - IOUtils.closeQuietly(bufferedOutputStream); - checkTempFileFromInputStream(tmpFile); - return tmpFile; - } catch (IOException e){ - if(tmpFile != null){ - tmpFile.delete(); - } - throw e; - } - } }