Merge remote-tracking branch 'eclipse/master'

Branch: master
Alex Black 2020-04-14 22:24:17 +10:00
commit 133223e865
2 changed files with 223 additions and 245 deletions

ModelGuesser.java

@@ -54,7 +54,11 @@ public class ModelGuesser {
      * @return the loaded normalizer
      */
    public static Normalizer<?> loadNormalizer(String path) {
+        try {
            return ModelSerializer.restoreNormalizerFromFile(new File(path));
+        } catch (IOException e){
+            throw new RuntimeException(e);
+        }
    }
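Reviewer note: ModelSerializer.restoreNormalizerFromFile now declares IOException (see the second file below), so ModelGuesser wraps it in a RuntimeException and keeps loadNormalizer's original signature. A minimal caller sketch, not part of this commit (path is a placeholder; ModelGuesser's package is assumed to be org.deeplearning4j.util):

    import java.io.File;
    import org.deeplearning4j.util.ModelGuesser;            // assumed package
    import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;

    class LoadNormalizerSketch {                             // hypothetical caller
        static Normalizer<?> load() {
            // No checked exception to handle: I/O failures surface as RuntimeException.
            return ModelGuesser.loadNormalizer("/path/to/model.zip");   // placeholder path
        }
    }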

ModelSerializer.java

@ -16,6 +16,7 @@
package org.deeplearning4j.util; package org.deeplearning4j.util;
import org.apache.commons.io.input.CloseShieldInputStream;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -43,9 +44,12 @@ import org.nd4j.linalg.primitives.Pair;
 import java.io.*;
 import java.util.ArrayList;
 import java.util.Enumeration;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.zip.ZipEntry;
 import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
 import java.util.zip.ZipOutputStream;
 
 /**
@@ -215,7 +219,31 @@ public class ModelSerializer {
      */
     public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
             throws IOException {
-        ZipFile zipFile = new ZipFile(file);
+        try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
+            return restoreMultiLayerNetwork(is, loadUpdater);
+        }
+    }
+
+    /**
+     * Load a MultiLayerNetwork from InputStream from an input stream<br>
+     * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
+     *
+     * @param is the inputstream to load from
+     * @return the loaded multi layer network
+     * @throws IOException
+     * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
+     */
+    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
+            throws IOException {
+        return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
+    }
+
+    private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
+            throws IOException {
+        checkInputStream(is);
+
+        Map<String, byte[]> zipFile = loadZipData(is);
 
         boolean gotConfig = false;
         boolean gotCoefficients = false;
@@ -229,11 +257,11 @@ public class ModelSerializer {
         DataSetPreProcessor preProcessor = null;
 
-        ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
+        byte[] config = zipFile.get(CONFIGURATION_JSON);
         if (config != null) {
             //restoring configuration
 
-            InputStream stream = zipFile.getInputStream(config);
+            InputStream stream = new ByteArrayInputStream(config);
             BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
             String line = "";
             StringBuilder js = new StringBuilder();
@@ -248,25 +276,25 @@ public class ModelSerializer {
         }
 
-        ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
+        byte[] coefficients = zipFile.get(COEFFICIENTS_BIN);
         if (coefficients != null ) {
-            if(coefficients.getSize() > 0) {
-                InputStream stream = zipFile.getInputStream(coefficients);
+            if(coefficients.length > 0) {
+                InputStream stream = new ByteArrayInputStream(coefficients);
                 DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                 params = Nd4j.read(dis);
 
                 dis.close();
                 gotCoefficients = true;
             } else {
-                ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
+                byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER);
                 gotCoefficients = (noParamsMarker != null);
             }
         }
 
         if (loadUpdater) {
-            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
+            byte[] updaterStateEntry = zipFile.get(UPDATER_BIN);
             if (updaterStateEntry != null) {
-                InputStream stream = zipFile.getInputStream(updaterStateEntry);
+                InputStream stream = new ByteArrayInputStream(updaterStateEntry);
                 DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
                 updaterState = Nd4j.read(dis);
@@ -275,9 +303,9 @@ public class ModelSerializer {
             }
         }
 
-        ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
+        byte[] prep = zipFile.get(PREPROCESSOR_BIN);
         if (prep != null) {
-            InputStream stream = zipFile.getInputStream(prep);
+            InputStream stream = new ByteArrayInputStream(prep);
             ObjectInputStream ois = new ObjectInputStream(stream);
             try {
@@ -290,7 +318,6 @@ public class ModelSerializer {
         }
 
-        zipFile.close();
 
         if (gotConfig && gotCoefficients) {
             MultiLayerConfiguration confFromJson;
@@ -322,37 +349,12 @@ public class ModelSerializer {
             if (gotUpdaterState && updaterState != null) {
                 network.getUpdater().setStateViewArray(network, updaterState, false);
             }
-            return network;
+            return new Pair<>(network, zipFile);
         } else
             throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
                             + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
     }
 
-    /**
-     * Load a MultiLayerNetwork from InputStream from an input stream<br>
-     * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
-     *
-     * @param is the inputstream to load from
-     * @return the loaded multi layer network
-     * @throws IOException
-     * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
-     */
-    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
-            throws IOException {
-        checkInputStream(is);
-
-        File tmpFile = null;
-        try{
-            tmpFile = tempFileFromStream(is);
-            return restoreMultiLayerNetwork(tmpFile, loadUpdater);
-        } finally {
-            if(tmpFile != null){
-                tmpFile.delete();
-            }
-        }
-    }
-
     /**
      * Restore a multi layer network from an input stream<br>
      * * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
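Reviewer note: restoreMultiLayerNetwork(File, boolean) now just opens a buffered stream and delegates to the InputStream overload, which reads the whole zip into memory via the loadZipData helper added at the end of the file, so no temporary files are created. A minimal caller sketch, not part of this commit (path is a placeholder):

    import java.io.*;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;

    class RestoreNetSketch {                                 // hypothetical caller
        static MultiLayerNetwork load(File zip) throws IOException {
            // Equivalent to the File overload; the stream is read fully and cannot be re-used.
            try (InputStream in = new BufferedInputStream(new FileInputStream(zip))) {
                return ModelSerializer.restoreMultiLayerNetwork(in, true);
            }
        }
    }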
@@ -403,16 +405,12 @@ public class ModelSerializer {
     public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
             @NonNull InputStream is, boolean loadUpdater) throws IOException {
         checkInputStream(is);
+        is = new CloseShieldInputStream(is);
 
-        File tmpFile = null;
-        try {
-            tmpFile = tempFileFromStream(is);
-            return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater);
-        } finally {
-            if (tmpFile != null) {
-                tmpFile.delete();
-            }
-        }
+        Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
+        MultiLayerNetwork net = p.getFirst();
+        Normalizer norm = restoreNormalizerFromMap(p.getSecond());
+        return new Pair<>(net, norm);
     }
 
     /**
@@ -425,9 +423,9 @@
      */
     public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater)
             throws IOException {
-        MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater);
-        Normalizer norm = restoreNormalizerFromFile(file);
-        return new Pair<>(net, norm);
+        try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
+            return restoreMultiLayerNetworkAndNormalizer(is, loadUpdater);
+        }
     }
 
     /**
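Reviewer note: because the helper returns both the network and the raw zip entries, restoreMultiLayerNetworkAndNormalizer now makes a single pass over the data instead of reading the file twice (once for the network, once for the normalizer). Usage is unchanged; a sketch, not part of this commit (path is a placeholder):

    import java.io.*;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
    import org.nd4j.linalg.primitives.Pair;

    class RestoreWithNormalizerSketch {                      // hypothetical caller
        static void load(File zip) throws IOException {
            Pair<MultiLayerNetwork, Normalizer> p =
                    ModelSerializer.restoreMultiLayerNetworkAndNormalizer(zip, true);
            MultiLayerNetwork net = p.getFirst();
            Normalizer norm = p.getSecond();                 // may be null if no normalizer was saved
        }
    }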
@@ -463,17 +461,126 @@
      */
     public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
             throws IOException {
+        return restoreComputationGraphHelper(is, loadUpdater).getFirst();
+    }
+
+    private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
+            throws IOException {
         checkInputStream(is);
 
-        File tmpFile = null;
-        try{
-            tmpFile = tempFileFromStream(is);
-            return restoreComputationGraph(tmpFile, loadUpdater);
-        } finally {
-            if(tmpFile != null){
-                tmpFile.delete();
-            }
-        }
+        Map<String, byte[]> files = loadZipData(is);
+
+        boolean gotConfig = false;
+        boolean gotCoefficients = false;
+        boolean gotUpdaterState = false;
+        boolean gotPreProcessor = false;
+
+        String json = "";
+        INDArray params = null;
+        INDArray updaterState = null;
+        DataSetPreProcessor preProcessor = null;
+
+        byte[] config = files.get(CONFIGURATION_JSON);
+        if (config != null) {
+            //restoring configuration
+
+            InputStream stream = new ByteArrayInputStream(config);
+            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
+            String line = "";
+            StringBuilder js = new StringBuilder();
+            while ((line = reader.readLine()) != null) {
+                js.append(line).append("\n");
+            }
+            json = js.toString();
+
+            reader.close();
+            stream.close();
+            gotConfig = true;
+        }
+
+        byte[] coefficients = files.get(COEFFICIENTS_BIN);
+        if (coefficients != null) {
+            if(coefficients.length > 0) {
+                InputStream stream = new ByteArrayInputStream(coefficients);
+                DataInputStream dis = new DataInputStream(stream);
+                params = Nd4j.read(dis);
+
+                dis.close();
+                gotCoefficients = true;
+            } else {
+                byte[] noParamsMarker = files.get(NO_PARAMS_MARKER);
+                gotCoefficients = (noParamsMarker != null);
+            }
+        }
+
+        if (loadUpdater) {
+            byte[] updaterStateEntry = files.get(UPDATER_BIN);
+            if (updaterStateEntry != null) {
+                InputStream stream = new ByteArrayInputStream(updaterStateEntry);
+                DataInputStream dis = new DataInputStream(stream);
+                updaterState = Nd4j.read(dis);
+
+                dis.close();
+                gotUpdaterState = true;
+            }
+        }
+
+        byte[] prep = files.get(PREPROCESSOR_BIN);
+        if (prep != null) {
+            InputStream stream = new ByteArrayInputStream(prep);
+            ObjectInputStream ois = new ObjectInputStream(stream);
+            try {
+                preProcessor = (DataSetPreProcessor) ois.readObject();
+            } catch (ClassNotFoundException e) {
+                throw new RuntimeException(e);
+            }
+            gotPreProcessor = true;
+        }
+
+        if (gotConfig && gotCoefficients) {
+            ComputationGraphConfiguration confFromJson;
+            try{
+                confFromJson = ComputationGraphConfiguration.fromJson(json);
+                if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
+                    //May be deserialized correctly, but mostly with null fields
+                    throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
+                }
+            } catch (Exception e){
+                if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
+                    throw e;
+                }
+                try{
+                    MultiLayerConfiguration.fromJson(json);
+                } catch (Exception e2){
+                    //Invalid, and not a compgraph
+                    throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
+                            " not a valid ComputationGraphConfiguration", e);
+                }
+                throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
+                        "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
+            }
+
+            //Handle legacy config - no network DataType in config, in beta3 or earlier
+            if(params != null)
+                confFromJson.setDataType(params.dataType());
+
+            ComputationGraph cg = new ComputationGraph(confFromJson);
+            cg.init(params, false);
+
+            if (gotUpdaterState && updaterState != null) {
+                cg.getUpdater().setStateViewArray(updaterState);
+            }
+            return new Pair<>(cg, files);
+        } else
+            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+                    + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
     }
 
     /**
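Reviewer note: the ComputationGraph path mirrors the MultiLayerNetwork path above: the stream is drained once into a Map of zip-entry byte arrays and the graph is rebuilt from those. A usage sketch, not part of this commit (path is a placeholder):

    import java.io.*;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.util.ModelSerializer;

    class RestoreGraphSketch {                               // hypothetical caller
        static ComputationGraph load() throws IOException {
            try (InputStream in = new BufferedInputStream(new FileInputStream("/path/to/graph.zip"))) {
                return ModelSerializer.restoreComputationGraph(in, false);
            }
        }
    }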
@@ -511,15 +618,11 @@
             @NonNull InputStream is, boolean loadUpdater) throws IOException {
         checkInputStream(is);
 
-        File tmpFile = null;
-        try {
-            tmpFile = tempFileFromStream(is);
-            return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater);
-        } finally {
-            if (tmpFile != null) {
-                tmpFile.delete();
-            }
-        }
+        Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
+        ComputationGraph net = p.getFirst();
+        Normalizer norm = restoreNormalizerFromMap(p.getSecond());
+        return new Pair<>(net, norm);
     }
 
     /**
@@ -532,9 +635,7 @@
      */
     public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
             throws IOException {
-        ComputationGraph net = restoreComputationGraph(file, loadUpdater);
-        Normalizer norm = restoreNormalizerFromFile(file);
-        return new Pair<>(net, norm);
+        return restoreComputationGraphAndNormalizer(new FileInputStream(file), loadUpdater);
     }
 
     /**
@@ -545,121 +646,7 @@
      * @throws IOException
      */
     public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
-        ZipFile zipFile = new ZipFile(file);
-
-        boolean gotConfig = false;
-        boolean gotCoefficients = false;
-        boolean gotUpdaterState = false;
-        boolean gotPreProcessor = false;
-
-        String json = "";
-        INDArray params = null;
-        INDArray updaterState = null;
-        DataSetPreProcessor preProcessor = null;
-
-        ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
-        if (config != null) {
-            //restoring configuration
-
-            InputStream stream = zipFile.getInputStream(config);
-            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
-            String line = "";
-            StringBuilder js = new StringBuilder();
-            while ((line = reader.readLine()) != null) {
-                js.append(line).append("\n");
-            }
-            json = js.toString();
-
-            reader.close();
-            stream.close();
-            gotConfig = true;
-        }
-
-        ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
-        if (coefficients != null) {
-            if(coefficients.getSize() > 0) {
-                InputStream stream = zipFile.getInputStream(coefficients);
-                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
-                params = Nd4j.read(dis);
-
-                dis.close();
-                gotCoefficients = true;
-            } else {
-                ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
-                gotCoefficients = (noParamsMarker != null);
-            }
-        }
-
-        if (loadUpdater) {
-            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
-            if (updaterStateEntry != null) {
-                InputStream stream = zipFile.getInputStream(updaterStateEntry);
-                DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
-                updaterState = Nd4j.read(dis);
-
-                dis.close();
-                gotUpdaterState = true;
-            }
-        }
-
-        ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
-        if (prep != null) {
-            InputStream stream = zipFile.getInputStream(prep);
-            ObjectInputStream ois = new ObjectInputStream(stream);
-            try {
-                preProcessor = (DataSetPreProcessor) ois.readObject();
-            } catch (ClassNotFoundException e) {
-                throw new RuntimeException(e);
-            }
-            gotPreProcessor = true;
-        }
-
-        zipFile.close();
-
-        if (gotConfig && gotCoefficients) {
-            ComputationGraphConfiguration confFromJson;
-            try{
-                confFromJson = ComputationGraphConfiguration.fromJson(json);
-                if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
-                    //May be deserialized correctly, but mostly with null fields
-                    throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
-                }
-            } catch (Exception e){
-                if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
-                    throw e;
-                }
-                try{
-                    MultiLayerConfiguration.fromJson(json);
-                } catch (Exception e2){
-                    //Invalid, and not a compgraph
-                    throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
-                            " not a valid ComputationGraphConfiguration", e);
-                }
-                throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
-                        "a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
-            }
-
-            //Handle legacy config - no network DataType in config, in beta3 or earlier
-            if(params != null)
-                confFromJson.setDataType(params.dataType());
-
-            ComputationGraph cg = new ComputationGraph(confFromJson);
-            cg.init(params, false);
-
-            if (gotUpdaterState && updaterState != null) {
-                cg.getUpdater().setStateViewArray(updaterState);
-            }
-            return cg;
-        } else
-            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
-                            + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
+        return restoreComputationGraph(new FileInputStream(file), loadUpdater);
     }
 
     /**
@@ -811,15 +798,16 @@
         }
 
         //Add new object:
-        ZipEntry entry = new ZipEntry("objects/" + key);
-        writeFile.putNextEntry(entry);
         try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
             oos.writeObject(o);
             byte[] bytes = baos.toByteArray();
+            ZipEntry entry = new ZipEntry("objects/" + key);
+            entry.setSize(bytes.length);
+            writeFile.putNextEntry(entry);
             writeFile.write(bytes);
-        }
             writeFile.closeEntry();
+        }
 
         writeFile.close();
         zipFile.close();
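Reviewer note: this hunk also fixes the entry-writing order in addObjectToFile: the object is serialized first, the ZipEntry is then created with its size set before putNextEntry, and closeEntry is called while the entry is still open inside the try block. A standalone sketch of that pattern with java.util.zip, not part of this commit (names are placeholders):

    import java.io.*;
    import java.util.zip.ZipEntry;
    import java.util.zip.ZipOutputStream;

    class ZipEntryOrderSketch {
        static void writeObjectEntry(ZipOutputStream zipOut, String key, Serializable o) throws IOException {
            byte[] bytes;
            try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                oos.writeObject(o);
                oos.flush();
                bytes = baos.toByteArray();                  // serialize before touching the zip
            }
            ZipEntry entry = new ZipEntry("objects/" + key); // placeholder key
            entry.setSize(bytes.length);                     // size known up front
            zipOut.putNextEntry(entry);
            zipOut.write(bytes);
            zipOut.closeEntry();
        }
    }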
@@ -904,18 +892,12 @@
      * @param file
      * @return
      */
-    public static <T extends Normalizer> T restoreNormalizerFromFile(File file) {
-        try (ZipFile zipFile = new ZipFile(file)) {
-            ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-            // checking for file existence
-            if (norm == null)
-                return null;
-            return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm));
+    public static <T extends Normalizer> T restoreNormalizerFromFile(File file) throws IOException {
+        try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
+            return restoreNormalizerFromInputStream(is);
         } catch (Exception e) {
             log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
-            DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file);
+            DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file));
 
             log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
             addNormalizerToModel(file, restoredDeprecated);
@@ -933,15 +915,21 @@
      */
     public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
         checkInputStream(is);
+        Map<String, byte[]> files = loadZipData(is);
+        return restoreNormalizerFromMap(files);
+    }
 
-        File tmpFile = null;
-        try {
-            tmpFile = tempFileFromStream(is);
-            return restoreNormalizerFromFile(tmpFile);
-        } finally {
-            if(tmpFile != null){
-                tmpFile.delete();
-            }
+    private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
+        byte[] norm = files.get(NORMALIZER_BIN);
+
+        // checking for file existence
+        if (norm == null)
+            return null;
+
+        try {
+            return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm));
+        }
+        catch (Exception e) {
+            throw new IOException("Error loading normalizer", e);
         }
     }
@@ -950,20 +938,10 @@
      *
      * This method restores normalizer from a given persisted model file serialized with Java object serialization
      *
-     * @param file
-     * @return
      */
-    private static DataNormalization restoreNormalizerFromFileDeprecated(File file) {
-        try (ZipFile zipFile = new ZipFile(file)) {
-            ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-
-            // checking for file existence
-            if (norm == null)
-                return null;
-
-            InputStream stream = zipFile.getInputStream(norm);
+    private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) {
+        try {
             ObjectInputStream ois = new ObjectInputStream(stream);
             try {
                 DataNormalization normalizer = (DataNormalization) ois.readObject();
                 return normalizer;
@@ -977,8 +955,6 @@
 
     private static void checkInputStream(InputStream inputStream) throws IOException {
-        /*
         //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
         int available;
         try{
@@ -993,34 +969,32 @@
             throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
                     "multiple times?");
         }
-        */
     }
 
-    private static void checkTempFileFromInputStream(File f) throws IOException {
-        if (f.length() <= 0) {
-            throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." +
-                    " Stream may have been closed before reading, is attempting to be used multiple times, or does not" +
-                    " point to a model file?");
-        }
-    }
+    private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
+        Map<String, byte[]> result = new HashMap<>();
+        try (final ZipInputStream zis = new ZipInputStream(is)) {
+            while (true) {
+                final ZipEntry zipEntry = zis.getNextEntry();
+                if (zipEntry == null)
+                    break;
+                if(zipEntry.isDirectory() || zipEntry.getSize() > Integer.MAX_VALUE)
+                    throw new IllegalArgumentException();
+
+                final int size = (int) (zipEntry.getSize());
+
+                final byte[] data;
+                if (size >= 0) { // known size
+                    data = IOUtils.readFully(zis, size);
+                }
+                else { // unknown size
+                    final ByteArrayOutputStream bout = new ByteArrayOutputStream();
+                    IOUtils.copy(zis, bout);
+                    data = bout.toByteArray();
+                }
+
+                result.put(zipEntry.getName(), data);
+            }
+        }
+        return result;
+    }
 
-    private static File tempFileFromStream(InputStream is) throws IOException{
-        checkInputStream(is);
-        String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
-        File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
-        try {
-            tmpFile.deleteOnExit();
-            BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
-            IOUtils.copy(is, bufferedOutputStream);
-            bufferedOutputStream.flush();
-            IOUtils.closeQuietly(bufferedOutputStream);
-            checkTempFileFromInputStream(tmpFile);
-            return tmpFile;
-        } catch (IOException e){
-            if(tmpFile != null){
-                tmpFile.delete();
-            }
-            throw e;
-        }
-    }
 }
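Reviewer note: with this commit every restore path in these two files funnels through loadZipData, which drains the stream once into memory; the temporary-file helpers (tempFileFromStream, checkTempFileFromInputStream) and their delete/deleteOnExit bookkeeping are gone. A final caller-side sketch for the normalizer API, whose File variant now declares IOException, not part of this commit (path is a placeholder):

    import java.io.*;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

    class RestoreNormalizerSketch {                          // hypothetical caller
        static DataNormalization load(File zip) throws IOException {
            // restoreNormalizerFromFile now declares IOException instead of swallowing it.
            return ModelSerializer.restoreNormalizerFromFile(zip);
        }
    }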