diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java
index 5e75bd3ce..96e29d1ac 100644
--- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java
+++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java
@@ -54,7 +54,11 @@ public class ModelGuesser {
* @return the loaded normalizer
*/
public static Normalizer> loadNormalizer(String path) {
- return ModelSerializer.restoreNormalizerFromFile(new File(path));
+ try {
+ return ModelSerializer.restoreNormalizerFromFile(new File(path));
+ } catch (IOException e){
+ throw new RuntimeException(e); // rethrown unchecked (cause preserved): this method declares no checked exceptions
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
index 70945b06c..30dbe7b4b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.util;
+import org.apache.commons.io.input.CloseShieldInputStream;
import org.nd4j.shade.guava.io.Files;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
@@ -43,9 +44,12 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.*;
import java.util.ArrayList;
import java.util.Enumeration;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
/**
@@ -215,7 +219,31 @@ public class ModelSerializer {
*/
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
throws IOException {
- ZipFile zipFile = new ZipFile(file);
+ try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ // file stream is closed when the try block exits
+ return restoreMultiLayerNetwork(is, loadUpdater); // delegates to the InputStream overload
+ }
+ }
+
+
+ /**
+ * Load a MultiLayerNetwork from an input stream.
+ * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
+ *
+ * @param is the input stream to load from
+ * @param loadUpdater if true, also restore the updater state when the saved model contains one
+ * @return the loaded multi layer network
+ * @throws IOException if the stream cannot be read
+ * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
+ */
+ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
+ throws IOException {
+ return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
+ }
+
+ private static Pair> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
+ throws IOException {
+ checkInputStream(is);
+
+ Map zipFile = loadZipData(is);
boolean gotConfig = false;
boolean gotCoefficients = false;
@@ -229,11 +257,11 @@ public class ModelSerializer {
DataSetPreProcessor preProcessor = null;
- ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
+ byte[] config = zipFile.get(CONFIGURATION_JSON);
if (config != null) {
//restoring configuration
- InputStream stream = zipFile.getInputStream(config);
+ InputStream stream = new ByteArrayInputStream(config);
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
StringBuilder js = new StringBuilder();
@@ -248,25 +276,25 @@ public class ModelSerializer {
}
- ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
+ byte[] coefficients = zipFile.get(COEFFICIENTS_BIN);
if (coefficients != null ) {
- if(coefficients.getSize() > 0) {
- InputStream stream = zipFile.getInputStream(coefficients);
+ if(coefficients.length > 0) {
+ InputStream stream = new ByteArrayInputStream(coefficients);
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
params = Nd4j.read(dis);
dis.close();
gotCoefficients = true;
} else {
- ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
+ byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER);
gotCoefficients = (noParamsMarker != null);
}
}
if (loadUpdater) {
- ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
+ byte[] updaterStateEntry = zipFile.get(UPDATER_BIN);
if (updaterStateEntry != null) {
- InputStream stream = zipFile.getInputStream(updaterStateEntry);
+ InputStream stream = new ByteArrayInputStream(updaterStateEntry);
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
updaterState = Nd4j.read(dis);
@@ -275,9 +303,9 @@ public class ModelSerializer {
}
}
- ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
+ byte[] prep = zipFile.get(PREPROCESSOR_BIN);
if (prep != null) {
- InputStream stream = zipFile.getInputStream(prep);
+ InputStream stream = new ByteArrayInputStream(prep);
ObjectInputStream ois = new ObjectInputStream(stream);
try {
@@ -290,7 +318,6 @@ public class ModelSerializer {
}
- zipFile.close();
if (gotConfig && gotCoefficients) {
MultiLayerConfiguration confFromJson;
@@ -322,37 +349,12 @@ public class ModelSerializer {
if (gotUpdaterState && updaterState != null) {
network.getUpdater().setStateViewArray(network, updaterState, false);
}
- return network;
+ return new Pair<>(network, zipFile);
} else
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
}
-
- /**
- * Load a MultiLayerNetwork from InputStream from an input stream
- * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- *
- * @param is the inputstream to load from
- * @return the loaded multi layer network
- * @throws IOException
- * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
- */
- public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
- throws IOException {
- checkInputStream(is);
-
- File tmpFile = null;
- try{
- tmpFile = tempFileFromStream(is);
- return restoreMultiLayerNetwork(tmpFile, loadUpdater);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
- }
- }
- }
-
/**
* Restore a multi layer network from an input stream
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
@@ -403,16 +405,12 @@ public class ModelSerializer {
public static Pair restoreMultiLayerNetworkAndNormalizer(
@NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is);
+ is = new CloseShieldInputStream(is); // NOTE(review): presumably shields the caller's stream from the close() performed while the zip is read - confirm this is intended; the File overload closes its stream via its own try-with-resources
- File tmpFile = null;
- try {
- tmpFile = tempFileFromStream(is);
- return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater);
- } finally {
- if (tmpFile != null) {
- tmpFile.delete();
- }
- }
+ Pair> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
+ MultiLayerNetwork net = p.getFirst();
+ Normalizer norm = restoreNormalizerFromMap(p.getSecond()); // normalizer comes from the zip entries already loaded for the network
+ return new Pair<>(net, norm);
}
/**
@@ -425,9 +423,9 @@ public class ModelSerializer {
*/
public static Pair restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater)
throws IOException {
- MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater);
- Normalizer norm = restoreNormalizerFromFile(file);
- return new Pair<>(net, norm);
+ try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ // file is opened once; the stream is closed when the try block exits
+ return restoreMultiLayerNetworkAndNormalizer(is, loadUpdater); // network and normalizer are both restored by the InputStream overload
+ }
}
/**
@@ -463,17 +461,126 @@ public class ModelSerializer {
*/
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
+ return restoreComputationGraphHelper(is, loadUpdater).getFirst(); // keep the graph; the raw zip entries the helper also returns are not needed here
+ }
+
+ private static Pair> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
+ throws IOException { // returns the restored graph paired with the zip entries it was built from
checkInputStream(is);
- File tmpFile = null;
- try{
- tmpFile = tempFileFromStream(is);
- return restoreComputationGraph(tmpFile, loadUpdater);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
+ Map files = loadZipData(is); // all zip entries are read into memory up front, keyed by entry name
+
+ boolean gotConfig = false;
+ boolean gotCoefficients = false;
+ boolean gotUpdaterState = false;
+ boolean gotPreProcessor = false;
+
+ String json = "";
+ INDArray params = null;
+ INDArray updaterState = null;
+ DataSetPreProcessor preProcessor = null;
+
+
+ byte[] config = files.get(CONFIGURATION_JSON);
+ if (config != null) {
+ //restoring configuration
+
+ InputStream stream = new ByteArrayInputStream(config);
+ BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
+ String line = "";
+ StringBuilder js = new StringBuilder();
+ while ((line = reader.readLine()) != null) {
+ js.append(line).append("\n");
+ }
+ json = js.toString();
+
+ reader.close();
+ stream.close();
+ gotConfig = true;
+ }
+
+
+ byte[] coefficients = files.get(COEFFICIENTS_BIN);
+ if (coefficients != null) {
+ if(coefficients.length > 0) {
+ InputStream stream = new ByteArrayInputStream(coefficients);
+ DataInputStream dis = new DataInputStream(stream);
+ params = Nd4j.read(dis);
+
+ dis.close();
+ gotCoefficients = true;
+ } else {
+ byte[] noParamsMarker = files.get(NO_PARAMS_MARKER); // an empty coefficients entry only counts when the no-params marker entry is present
+ gotCoefficients = (noParamsMarker != null);
}
}
+
+
+ if (loadUpdater) {
+ byte[] updaterStateEntry = files.get(UPDATER_BIN);
+ if (updaterStateEntry != null) {
+ InputStream stream = new ByteArrayInputStream(updaterStateEntry);
+ DataInputStream dis = new DataInputStream(stream);
+ updaterState = Nd4j.read(dis);
+
+ dis.close();
+ gotUpdaterState = true;
+ }
+ }
+
+ byte[] prep = files.get(PREPROCESSOR_BIN);
+ if (prep != null) {
+ InputStream stream = new ByteArrayInputStream(prep);
+ ObjectInputStream ois = new ObjectInputStream(stream); // NOTE(review): Java-native deserialization of archive content - only load archives from trusted sources
+
+ try {
+ preProcessor = (DataSetPreProcessor) ois.readObject();
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+
+ gotPreProcessor = true; // NOTE(review): preProcessor / gotPreProcessor are not read again in this method
+ }
+
+
+ if (gotConfig && gotCoefficients) {
+ ComputationGraphConfiguration confFromJson;
+ try{
+ confFromJson = ComputationGraphConfiguration.fromJson(json);
+ if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
+ //May be deserialized correctly, but mostly with null fields
+ throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
+ }
+ } catch (Exception e){
+ if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
+ throw e; // rethrown unchanged when the message mentions registerLegacyCustomClassesForJSON
+ }
+ try{
+ MultiLayerConfiguration.fromJson(json); // probe: does the JSON parse as a MultiLayerConfiguration instead?
+ } catch (Exception e2){
+ //Invalid, and not a compgraph
+ throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
+ " not a valid ComputationGraphConfiguration", e);
+ }
+ throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
+ "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
+ }
+
+ //Handle legacy config - no network DataType in config, in beta3 or earlier
+ if(params != null)
+ confFromJson.setDataType(params.dataType());
+
+ ComputationGraph cg = new ComputationGraph(confFromJson);
+ cg.init(params, false);
+
+
+ if (gotUpdaterState && updaterState != null) {
+ cg.getUpdater().setStateViewArray(updaterState);
+ }
+ return new Pair<>(cg, files); // entries are handed back so callers can read further entries (e.g. the normalizer) without re-reading the stream
+ } else
+ throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
+ }
/**
@@ -511,15 +618,11 @@ public class ModelSerializer {
@NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is);
- File tmpFile = null;
- try {
- tmpFile = tempFileFromStream(is);
- return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater);
- } finally {
- if (tmpFile != null) {
- tmpFile.delete();
- }
- }
+
+ Pair> p = restoreComputationGraphHelper(is, loadUpdater);
+ ComputationGraph net = p.getFirst();
+ Normalizer norm = restoreNormalizerFromMap(p.getSecond());
+ return new Pair<>(net, norm);
}
/**
@@ -532,9 +635,7 @@ public class ModelSerializer {
*/
public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
throws IOException {
- ComputationGraph net = restoreComputationGraph(file, loadUpdater);
- Normalizer norm = restoreNormalizerFromFile(file);
- return new Pair<>(net, norm);
+     // try-with-resources releases the file handle even when loading fails before the stream
+     // has been consumed; buffered like restoreMultiLayerNetworkAndNormalizer(File, boolean)
+     try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
+         return restoreComputationGraphAndNormalizer(is, loadUpdater);
+     }
}
/**
@@ -545,121 +646,7 @@ public class ModelSerializer {
* @throws IOException
*/
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
- ZipFile zipFile = new ZipFile(file);
-
- boolean gotConfig = false;
- boolean gotCoefficients = false;
- boolean gotUpdaterState = false;
- boolean gotPreProcessor = false;
-
- String json = "";
- INDArray params = null;
- INDArray updaterState = null;
- DataSetPreProcessor preProcessor = null;
-
-
- ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
- if (config != null) {
- //restoring configuration
-
- InputStream stream = zipFile.getInputStream(config);
- BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
- String line = "";
- StringBuilder js = new StringBuilder();
- while ((line = reader.readLine()) != null) {
- js.append(line).append("\n");
- }
- json = js.toString();
-
- reader.close();
- stream.close();
- gotConfig = true;
- }
-
-
- ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
- if (coefficients != null) {
- if(coefficients.getSize() > 0) {
- InputStream stream = zipFile.getInputStream(coefficients);
- DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
- params = Nd4j.read(dis);
-
- dis.close();
- gotCoefficients = true;
- } else {
- ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
- gotCoefficients = (noParamsMarker != null);
- }
- }
-
-
- if (loadUpdater) {
- ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
- if (updaterStateEntry != null) {
- InputStream stream = zipFile.getInputStream(updaterStateEntry);
- DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
- updaterState = Nd4j.read(dis);
-
- dis.close();
- gotUpdaterState = true;
- }
- }
-
- ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
- if (prep != null) {
- InputStream stream = zipFile.getInputStream(prep);
- ObjectInputStream ois = new ObjectInputStream(stream);
-
- try {
- preProcessor = (DataSetPreProcessor) ois.readObject();
- } catch (ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
-
- gotPreProcessor = true;
- }
-
-
- zipFile.close();
-
- if (gotConfig && gotCoefficients) {
- ComputationGraphConfiguration confFromJson;
- try{
- confFromJson = ComputationGraphConfiguration.fromJson(json);
- if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
- //May be deserialized correctly, but mostly with null fields
- throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
- }
- } catch (Exception e){
- if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
- throw e;
- }
- try{
- MultiLayerConfiguration.fromJson(json);
- } catch (Exception e2){
- //Invalid, and not a compgraph
- throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
- " not a valid ComputationGraphConfiguration", e);
- }
- throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
- "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
- }
-
- //Handle legacy config - no network DataType in config, in beta3 or earlier
- if(params != null)
- confFromJson.setDataType(params.dataType());
-
- ComputationGraph cg = new ComputationGraph(confFromJson);
- cg.init(params, false);
-
-
- if (gotUpdaterState && updaterState != null) {
- cg.getUpdater().setStateViewArray(updaterState);
- }
- return cg;
- } else
- throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
- + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
+     // try-with-resources releases the file handle even when loading fails before the stream
+     // has been consumed; buffered like the restoreMultiLayerNetwork(File, boolean) overload
+     try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
+         return restoreComputationGraph(is, loadUpdater);
+     }
}
/**
@@ -811,15 +798,16 @@ public class ModelSerializer {
}
//Add new object:
- ZipEntry entry = new ZipEntry("objects/" + key);
- writeFile.putNextEntry(entry);
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(o);
byte[] bytes = baos.toByteArray();
+ ZipEntry entry = new ZipEntry("objects/" + key);
+ entry.setSize(bytes.length);
+ writeFile.putNextEntry(entry);
writeFile.write(bytes);
+ writeFile.closeEntry();
}
- writeFile.closeEntry();
writeFile.close();
zipFile.close();
@@ -904,18 +892,12 @@ public class ModelSerializer {
* @param file
* @return
*/
- public static T restoreNormalizerFromFile(File file) {
- try (ZipFile zipFile = new ZipFile(file)) {
- ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-
- // checking for file existence
- if (norm == null)
- return null;
-
- return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm));
+ public static T restoreNormalizerFromFile(File file) throws IOException {
+ try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
+ return restoreNormalizerFromInputStream(is);
} catch (Exception e) {
log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
- DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file);
+ DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file));
log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
addNormalizerToModel(file, restoredDeprecated);
@@ -933,16 +915,22 @@ public class ModelSerializer {
*/
public static T restoreNormalizerFromInputStream(InputStream is) throws IOException {
checkInputStream(is);
+ Map files = loadZipData(is); // reads every zip entry into memory, keyed by entry name
+ return restoreNormalizerFromMap(files); // returns null when the archive has no normalizer entry
+ }
- File tmpFile = null;
+ private static T restoreNormalizerFromMap(Map files) throws IOException {
+ byte[] norm = files.get(NORMALIZER_BIN);
+
+ // no normalizer entry among the loaded zip entries: nothing to restore
+ if (norm == null)
+ return null;
try {
- tmpFile = tempFileFromStream(is);
- return restoreNormalizerFromFile(tmpFile);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
- }
+ return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm));
}
+ catch (Exception e) {
+ throw new IOException("Error loading normalizer", e); // any restore failure is reported as IOException with the original cause attached
+ }
}
/**
@@ -950,20 +938,10 @@ public class ModelSerializer {
*
* This method restores normalizer from a given persisted model file serialized with Java object serialization
*
- * @param file
- * @return
*/
- private static DataNormalization restoreNormalizerFromFileDeprecated(File file) {
- try (ZipFile zipFile = new ZipFile(file)) {
- ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-
- // checking for file existence
- if (norm == null)
- return null;
-
- InputStream stream = zipFile.getInputStream(norm);
+ private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) {
+ try {
ObjectInputStream ois = new ObjectInputStream(stream);
-
try {
DataNormalization normalizer = (DataNormalization) ois.readObject();
return normalizer;
@@ -977,8 +955,6 @@ public class ModelSerializer {
private static void checkInputStream(InputStream inputStream) throws IOException {
-
- /*
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
int available;
try{
@@ -993,34 +969,32 @@ public class ModelSerializer {
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
"multiple times?");
}
- */
}
- private static void checkTempFileFromInputStream(File f) throws IOException {
- if (f.length() <= 0) {
- throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." +
- " Stream may have been closed before reading, is attempting to be used multiple times, or does not" +
- " point to a model file?");
- }
+ /**
+  * Reads every file entry of a zip archive from the given stream into memory.
+  * The ZipInputStream wrapped around {@code is} is closed before this method returns.
+  * Directory entries have no content and are skipped.
+  *
+  * @param is stream positioned at the start of a zip archive
+  * @return map from entry name to the entry's bytes
+  * @throws IOException if reading from the stream fails
+  * @throws IllegalArgumentException if an entry declares a size above Integer.MAX_VALUE
+  */
+ private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
+     Map<String, byte[]> result = new HashMap<>();
+     try (final ZipInputStream zis = new ZipInputStream(is)) {
+         ZipEntry zipEntry;
+         while ((zipEntry = zis.getNextEntry()) != null) {
+             if (zipEntry.isDirectory())
+                 continue; // nothing to load for a directory entry; do not reject the whole archive
+
+             final long declaredSize = zipEntry.getSize();
+             if (declaredSize > Integer.MAX_VALUE)
+                 throw new IllegalArgumentException("Zip entry \"" + zipEntry.getName()
+                         + "\" is too large to load into memory: " + declaredSize + " bytes");
+
+             final byte[] data;
+             if (declaredSize >= 0) { // known size
+                 data = IOUtils.readFully(zis, (int) declaredSize);
+             }
+             else { // unknown size: getSize() returned a negative value
+                 final ByteArrayOutputStream bout = new ByteArrayOutputStream();
+                 IOUtils.copy(zis, bout);
+                 data = bout.toByteArray();
+             }
+             result.put(zipEntry.getName(), data);
+         }
+     }
+     return result;
}
- private static File tempFileFromStream(InputStream is) throws IOException{
- checkInputStream(is);
- String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
- File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
- try {
- tmpFile.deleteOnExit();
- BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
- IOUtils.copy(is, bufferedOutputStream);
- bufferedOutputStream.flush();
- IOUtils.closeQuietly(bufferedOutputStream);
- checkTempFileFromInputStream(tmpFile);
- return tmpFile;
- } catch (IOException e){
- if(tmpFile != null){
- tmpFile.delete();
- }
- throw e;
- }
- }
}