Merge remote-tracking branch 'eclipse/master'
commit
133223e865
|
@ -54,7 +54,11 @@ public class ModelGuesser {
|
||||||
* @return the loaded normalizer
|
* @return the loaded normalizer
|
||||||
*/
|
*/
|
||||||
public static Normalizer<?> loadNormalizer(String path) {
|
public static Normalizer<?> loadNormalizer(String path) {
|
||||||
|
try {
|
||||||
return ModelSerializer.restoreNormalizerFromFile(new File(path));
|
return ModelSerializer.restoreNormalizerFromFile(new File(path));
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.util;
|
package org.deeplearning4j.util;
|
||||||
|
|
||||||
|
import org.apache.commons.io.input.CloseShieldInputStream;
|
||||||
import org.nd4j.shade.guava.io.Files;
|
import org.nd4j.shade.guava.io.Files;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -43,9 +44,12 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Enumeration;
|
import java.util.Enumeration;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.zip.ZipEntry;
|
import java.util.zip.ZipEntry;
|
||||||
import java.util.zip.ZipFile;
|
import java.util.zip.ZipFile;
|
||||||
|
import java.util.zip.ZipInputStream;
|
||||||
import java.util.zip.ZipOutputStream;
|
import java.util.zip.ZipOutputStream;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -215,7 +219,31 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
|
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
ZipFile zipFile = new ZipFile(file);
|
try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
|
||||||
|
return restoreMultiLayerNetwork(is, loadUpdater);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a MultiLayerNetwork from InputStream from an input stream<br>
|
||||||
|
* Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
|
||||||
|
*
|
||||||
|
* @param is the inputstream to load from
|
||||||
|
* @return the loaded multi layer network
|
||||||
|
* @throws IOException
|
||||||
|
* @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
|
||||||
|
*/
|
||||||
|
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
|
||||||
|
throws IOException {
|
||||||
|
return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||||
|
throws IOException {
|
||||||
|
checkInputStream(is);
|
||||||
|
|
||||||
|
Map<String, byte[]> zipFile = loadZipData(is);
|
||||||
|
|
||||||
boolean gotConfig = false;
|
boolean gotConfig = false;
|
||||||
boolean gotCoefficients = false;
|
boolean gotCoefficients = false;
|
||||||
|
@ -229,11 +257,11 @@ public class ModelSerializer {
|
||||||
DataSetPreProcessor preProcessor = null;
|
DataSetPreProcessor preProcessor = null;
|
||||||
|
|
||||||
|
|
||||||
ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
|
byte[] config = zipFile.get(CONFIGURATION_JSON);
|
||||||
if (config != null) {
|
if (config != null) {
|
||||||
//restoring configuration
|
//restoring configuration
|
||||||
|
|
||||||
InputStream stream = zipFile.getInputStream(config);
|
InputStream stream = new ByteArrayInputStream(config);
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||||
String line = "";
|
String line = "";
|
||||||
StringBuilder js = new StringBuilder();
|
StringBuilder js = new StringBuilder();
|
||||||
|
@ -248,25 +276,25 @@ public class ModelSerializer {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
|
byte[] coefficients = zipFile.get(COEFFICIENTS_BIN);
|
||||||
if (coefficients != null ) {
|
if (coefficients != null ) {
|
||||||
if(coefficients.getSize() > 0) {
|
if(coefficients.length > 0) {
|
||||||
InputStream stream = zipFile.getInputStream(coefficients);
|
InputStream stream = new ByteArrayInputStream(coefficients);
|
||||||
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
||||||
params = Nd4j.read(dis);
|
params = Nd4j.read(dis);
|
||||||
|
|
||||||
dis.close();
|
dis.close();
|
||||||
gotCoefficients = true;
|
gotCoefficients = true;
|
||||||
} else {
|
} else {
|
||||||
ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
|
byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER);
|
||||||
gotCoefficients = (noParamsMarker != null);
|
gotCoefficients = (noParamsMarker != null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (loadUpdater) {
|
if (loadUpdater) {
|
||||||
ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
|
byte[] updaterStateEntry = zipFile.get(UPDATER_BIN);
|
||||||
if (updaterStateEntry != null) {
|
if (updaterStateEntry != null) {
|
||||||
InputStream stream = zipFile.getInputStream(updaterStateEntry);
|
InputStream stream = new ByteArrayInputStream(updaterStateEntry);
|
||||||
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
||||||
updaterState = Nd4j.read(dis);
|
updaterState = Nd4j.read(dis);
|
||||||
|
|
||||||
|
@ -275,9 +303,9 @@ public class ModelSerializer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
|
byte[] prep = zipFile.get(PREPROCESSOR_BIN);
|
||||||
if (prep != null) {
|
if (prep != null) {
|
||||||
InputStream stream = zipFile.getInputStream(prep);
|
InputStream stream = new ByteArrayInputStream(prep);
|
||||||
ObjectInputStream ois = new ObjectInputStream(stream);
|
ObjectInputStream ois = new ObjectInputStream(stream);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
@ -290,7 +318,6 @@ public class ModelSerializer {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
zipFile.close();
|
|
||||||
|
|
||||||
if (gotConfig && gotCoefficients) {
|
if (gotConfig && gotCoefficients) {
|
||||||
MultiLayerConfiguration confFromJson;
|
MultiLayerConfiguration confFromJson;
|
||||||
|
@ -322,37 +349,12 @@ public class ModelSerializer {
|
||||||
if (gotUpdaterState && updaterState != null) {
|
if (gotUpdaterState && updaterState != null) {
|
||||||
network.getUpdater().setStateViewArray(network, updaterState, false);
|
network.getUpdater().setStateViewArray(network, updaterState, false);
|
||||||
}
|
}
|
||||||
return network;
|
return new Pair<>(network, zipFile);
|
||||||
} else
|
} else
|
||||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load a MultiLayerNetwork from InputStream from an input stream<br>
|
|
||||||
* Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
|
|
||||||
*
|
|
||||||
* @param is the inputstream to load from
|
|
||||||
* @return the loaded multi layer network
|
|
||||||
* @throws IOException
|
|
||||||
* @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
|
|
||||||
*/
|
|
||||||
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
|
|
||||||
throws IOException {
|
|
||||||
checkInputStream(is);
|
|
||||||
|
|
||||||
File tmpFile = null;
|
|
||||||
try{
|
|
||||||
tmpFile = tempFileFromStream(is);
|
|
||||||
return restoreMultiLayerNetwork(tmpFile, loadUpdater);
|
|
||||||
} finally {
|
|
||||||
if(tmpFile != null){
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Restore a multi layer network from an input stream<br>
|
* Restore a multi layer network from an input stream<br>
|
||||||
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
|
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
|
||||||
|
@ -403,16 +405,12 @@ public class ModelSerializer {
|
||||||
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
|
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
|
||||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
is = new CloseShieldInputStream(is);
|
||||||
|
|
||||||
File tmpFile = null;
|
Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
|
||||||
try {
|
MultiLayerNetwork net = p.getFirst();
|
||||||
tmpFile = tempFileFromStream(is);
|
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||||
return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater);
|
return new Pair<>(net, norm);
|
||||||
} finally {
|
|
||||||
if (tmpFile != null) {
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -425,9 +423,9 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater)
|
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater);
|
try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
|
||||||
Normalizer norm = restoreNormalizerFromFile(file);
|
return restoreMultiLayerNetworkAndNormalizer(is, loadUpdater);
|
||||||
return new Pair<>(net, norm);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -463,17 +461,126 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
|
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
return restoreComputationGraphHelper(is, loadUpdater).getFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||||
|
throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
File tmpFile = null;
|
Map<String, byte[]> files = loadZipData(is);
|
||||||
|
|
||||||
|
boolean gotConfig = false;
|
||||||
|
boolean gotCoefficients = false;
|
||||||
|
boolean gotUpdaterState = false;
|
||||||
|
boolean gotPreProcessor = false;
|
||||||
|
|
||||||
|
String json = "";
|
||||||
|
INDArray params = null;
|
||||||
|
INDArray updaterState = null;
|
||||||
|
DataSetPreProcessor preProcessor = null;
|
||||||
|
|
||||||
|
|
||||||
|
byte[] config = files.get(CONFIGURATION_JSON);
|
||||||
|
if (config != null) {
|
||||||
|
//restoring configuration
|
||||||
|
|
||||||
|
InputStream stream = new ByteArrayInputStream(config);
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||||
|
String line = "";
|
||||||
|
StringBuilder js = new StringBuilder();
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
js.append(line).append("\n");
|
||||||
|
}
|
||||||
|
json = js.toString();
|
||||||
|
|
||||||
|
reader.close();
|
||||||
|
stream.close();
|
||||||
|
gotConfig = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
byte[] coefficients = files.get(COEFFICIENTS_BIN);
|
||||||
|
if (coefficients != null) {
|
||||||
|
if(coefficients.length > 0) {
|
||||||
|
InputStream stream = new ByteArrayInputStream(coefficients);
|
||||||
|
DataInputStream dis = new DataInputStream(stream);
|
||||||
|
params = Nd4j.read(dis);
|
||||||
|
|
||||||
|
dis.close();
|
||||||
|
gotCoefficients = true;
|
||||||
|
} else {
|
||||||
|
byte[] noParamsMarker = files.get(NO_PARAMS_MARKER);
|
||||||
|
gotCoefficients = (noParamsMarker != null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (loadUpdater) {
|
||||||
|
byte[] updaterStateEntry = files.get(UPDATER_BIN);
|
||||||
|
if (updaterStateEntry != null) {
|
||||||
|
InputStream stream = new ByteArrayInputStream(updaterStateEntry);
|
||||||
|
DataInputStream dis = new DataInputStream(stream);
|
||||||
|
updaterState = Nd4j.read(dis);
|
||||||
|
|
||||||
|
dis.close();
|
||||||
|
gotUpdaterState = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
byte[] prep = files.get(PREPROCESSOR_BIN);
|
||||||
|
if (prep != null) {
|
||||||
|
InputStream stream = new ByteArrayInputStream(prep);
|
||||||
|
ObjectInputStream ois = new ObjectInputStream(stream);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
tmpFile = tempFileFromStream(is);
|
preProcessor = (DataSetPreProcessor) ois.readObject();
|
||||||
return restoreComputationGraph(tmpFile, loadUpdater);
|
} catch (ClassNotFoundException e) {
|
||||||
} finally {
|
throw new RuntimeException(e);
|
||||||
if(tmpFile != null){
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gotPreProcessor = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (gotConfig && gotCoefficients) {
|
||||||
|
ComputationGraphConfiguration confFromJson;
|
||||||
|
try{
|
||||||
|
confFromJson = ComputationGraphConfiguration.fromJson(json);
|
||||||
|
if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
|
||||||
|
//May be deserialized correctly, but mostly with null fields
|
||||||
|
throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
|
||||||
|
}
|
||||||
|
} catch (Exception e){
|
||||||
|
if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
try{
|
||||||
|
MultiLayerConfiguration.fromJson(json);
|
||||||
|
} catch (Exception e2){
|
||||||
|
//Invalid, and not a compgraph
|
||||||
|
throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
|
||||||
|
" not a valid ComputationGraphConfiguration", e);
|
||||||
|
}
|
||||||
|
throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
|
||||||
|
"a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
|
||||||
|
}
|
||||||
|
|
||||||
|
//Handle legacy config - no network DataType in config, in beta3 or earlier
|
||||||
|
if(params != null)
|
||||||
|
confFromJson.setDataType(params.dataType());
|
||||||
|
|
||||||
|
ComputationGraph cg = new ComputationGraph(confFromJson);
|
||||||
|
cg.init(params, false);
|
||||||
|
|
||||||
|
|
||||||
|
if (gotUpdaterState && updaterState != null) {
|
||||||
|
cg.getUpdater().setStateViewArray(updaterState);
|
||||||
|
}
|
||||||
|
return new Pair<>(cg, files);
|
||||||
|
} else
|
||||||
|
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||||
|
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -511,15 +618,11 @@ public class ModelSerializer {
|
||||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
File tmpFile = null;
|
|
||||||
try {
|
Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
|
||||||
tmpFile = tempFileFromStream(is);
|
ComputationGraph net = p.getFirst();
|
||||||
return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater);
|
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||||
} finally {
|
return new Pair<>(net, norm);
|
||||||
if (tmpFile != null) {
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -532,9 +635,7 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
|
public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
ComputationGraph net = restoreComputationGraph(file, loadUpdater);
|
return restoreComputationGraphAndNormalizer(new FileInputStream(file), loadUpdater);
|
||||||
Normalizer norm = restoreNormalizerFromFile(file);
|
|
||||||
return new Pair<>(net, norm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -545,121 +646,7 @@ public class ModelSerializer {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
|
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
|
||||||
ZipFile zipFile = new ZipFile(file);
|
return restoreComputationGraph(new FileInputStream(file), loadUpdater);
|
||||||
|
|
||||||
boolean gotConfig = false;
|
|
||||||
boolean gotCoefficients = false;
|
|
||||||
boolean gotUpdaterState = false;
|
|
||||||
boolean gotPreProcessor = false;
|
|
||||||
|
|
||||||
String json = "";
|
|
||||||
INDArray params = null;
|
|
||||||
INDArray updaterState = null;
|
|
||||||
DataSetPreProcessor preProcessor = null;
|
|
||||||
|
|
||||||
|
|
||||||
ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
|
|
||||||
if (config != null) {
|
|
||||||
//restoring configuration
|
|
||||||
|
|
||||||
InputStream stream = zipFile.getInputStream(config);
|
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
|
||||||
String line = "";
|
|
||||||
StringBuilder js = new StringBuilder();
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
js.append(line).append("\n");
|
|
||||||
}
|
|
||||||
json = js.toString();
|
|
||||||
|
|
||||||
reader.close();
|
|
||||||
stream.close();
|
|
||||||
gotConfig = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
|
|
||||||
if (coefficients != null) {
|
|
||||||
if(coefficients.getSize() > 0) {
|
|
||||||
InputStream stream = zipFile.getInputStream(coefficients);
|
|
||||||
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
|
||||||
params = Nd4j.read(dis);
|
|
||||||
|
|
||||||
dis.close();
|
|
||||||
gotCoefficients = true;
|
|
||||||
} else {
|
|
||||||
ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
|
|
||||||
gotCoefficients = (noParamsMarker != null);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (loadUpdater) {
|
|
||||||
ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
|
|
||||||
if (updaterStateEntry != null) {
|
|
||||||
InputStream stream = zipFile.getInputStream(updaterStateEntry);
|
|
||||||
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
|
|
||||||
updaterState = Nd4j.read(dis);
|
|
||||||
|
|
||||||
dis.close();
|
|
||||||
gotUpdaterState = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
|
|
||||||
if (prep != null) {
|
|
||||||
InputStream stream = zipFile.getInputStream(prep);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(stream);
|
|
||||||
|
|
||||||
try {
|
|
||||||
preProcessor = (DataSetPreProcessor) ois.readObject();
|
|
||||||
} catch (ClassNotFoundException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
gotPreProcessor = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
zipFile.close();
|
|
||||||
|
|
||||||
if (gotConfig && gotCoefficients) {
|
|
||||||
ComputationGraphConfiguration confFromJson;
|
|
||||||
try{
|
|
||||||
confFromJson = ComputationGraphConfiguration.fromJson(json);
|
|
||||||
if(confFromJson.getNetworkInputs() == null && (confFromJson.getVertices() == null || confFromJson.getVertices().size() == 0)){
|
|
||||||
//May be deserialized correctly, but mostly with null fields
|
|
||||||
throw new RuntimeException("Invalid JSON - not a ComputationGraphConfiguration");
|
|
||||||
}
|
|
||||||
} catch (Exception e){
|
|
||||||
if(e.getMessage() != null && e.getMessage().contains("registerLegacyCustomClassesForJSON")){
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
try{
|
|
||||||
MultiLayerConfiguration.fromJson(json);
|
|
||||||
} catch (Exception e2){
|
|
||||||
//Invalid, and not a compgraph
|
|
||||||
throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" +
|
|
||||||
" not a valid ComputationGraphConfiguration", e);
|
|
||||||
}
|
|
||||||
throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model appears to be " +
|
|
||||||
"a MultiLayerNetwork - use ModelSerializer.restoreMultiLayerNetwork instead");
|
|
||||||
}
|
|
||||||
|
|
||||||
//Handle legacy config - no network DataType in config, in beta3 or earlier
|
|
||||||
if(params != null)
|
|
||||||
confFromJson.setDataType(params.dataType());
|
|
||||||
|
|
||||||
ComputationGraph cg = new ComputationGraph(confFromJson);
|
|
||||||
cg.init(params, false);
|
|
||||||
|
|
||||||
|
|
||||||
if (gotUpdaterState && updaterState != null) {
|
|
||||||
cg.getUpdater().setStateViewArray(updaterState);
|
|
||||||
}
|
|
||||||
return cg;
|
|
||||||
} else
|
|
||||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
|
||||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -811,15 +798,16 @@ public class ModelSerializer {
|
||||||
}
|
}
|
||||||
|
|
||||||
//Add new object:
|
//Add new object:
|
||||||
ZipEntry entry = new ZipEntry("objects/" + key);
|
|
||||||
writeFile.putNextEntry(entry);
|
|
||||||
|
|
||||||
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
||||||
oos.writeObject(o);
|
oos.writeObject(o);
|
||||||
byte[] bytes = baos.toByteArray();
|
byte[] bytes = baos.toByteArray();
|
||||||
|
ZipEntry entry = new ZipEntry("objects/" + key);
|
||||||
|
entry.setSize(bytes.length);
|
||||||
|
writeFile.putNextEntry(entry);
|
||||||
writeFile.write(bytes);
|
writeFile.write(bytes);
|
||||||
}
|
|
||||||
writeFile.closeEntry();
|
writeFile.closeEntry();
|
||||||
|
}
|
||||||
|
|
||||||
writeFile.close();
|
writeFile.close();
|
||||||
zipFile.close();
|
zipFile.close();
|
||||||
|
@ -904,18 +892,12 @@ public class ModelSerializer {
|
||||||
* @param file
|
* @param file
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static <T extends Normalizer> T restoreNormalizerFromFile(File file) {
|
public static <T extends Normalizer> T restoreNormalizerFromFile(File file) throws IOException {
|
||||||
try (ZipFile zipFile = new ZipFile(file)) {
|
try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
|
||||||
ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
|
return restoreNormalizerFromInputStream(is);
|
||||||
|
|
||||||
// checking for file existence
|
|
||||||
if (norm == null)
|
|
||||||
return null;
|
|
||||||
|
|
||||||
return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm));
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
|
log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
|
||||||
DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file);
|
DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file));
|
||||||
|
|
||||||
log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
|
log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
|
||||||
addNormalizerToModel(file, restoredDeprecated);
|
addNormalizerToModel(file, restoredDeprecated);
|
||||||
|
@ -933,15 +915,21 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
|
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
Map<String, byte[]> files = loadZipData(is);
|
||||||
File tmpFile = null;
|
return restoreNormalizerFromMap(files);
|
||||||
try {
|
|
||||||
tmpFile = tempFileFromStream(is);
|
|
||||||
return restoreNormalizerFromFile(tmpFile);
|
|
||||||
} finally {
|
|
||||||
if(tmpFile != null){
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
|
||||||
|
byte[] norm = files.get(NORMALIZER_BIN);
|
||||||
|
|
||||||
|
// checking for file existence
|
||||||
|
if (norm == null)
|
||||||
|
return null;
|
||||||
|
try {
|
||||||
|
return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm));
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
throw new IOException("Error loading normalizer", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -950,20 +938,10 @@ public class ModelSerializer {
|
||||||
*
|
*
|
||||||
* This method restores normalizer from a given persisted model file serialized with Java object serialization
|
* This method restores normalizer from a given persisted model file serialized with Java object serialization
|
||||||
*
|
*
|
||||||
* @param file
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
private static DataNormalization restoreNormalizerFromFileDeprecated(File file) {
|
private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) {
|
||||||
try (ZipFile zipFile = new ZipFile(file)) {
|
try {
|
||||||
ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
|
|
||||||
|
|
||||||
// checking for file existence
|
|
||||||
if (norm == null)
|
|
||||||
return null;
|
|
||||||
|
|
||||||
InputStream stream = zipFile.getInputStream(norm);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(stream);
|
ObjectInputStream ois = new ObjectInputStream(stream);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
DataNormalization normalizer = (DataNormalization) ois.readObject();
|
DataNormalization normalizer = (DataNormalization) ois.readObject();
|
||||||
return normalizer;
|
return normalizer;
|
||||||
|
@ -977,8 +955,6 @@ public class ModelSerializer {
|
||||||
|
|
||||||
|
|
||||||
private static void checkInputStream(InputStream inputStream) throws IOException {
|
private static void checkInputStream(InputStream inputStream) throws IOException {
|
||||||
|
|
||||||
/*
|
|
||||||
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
||||||
int available;
|
int available;
|
||||||
try{
|
try{
|
||||||
|
@ -993,34 +969,32 @@ public class ModelSerializer {
|
||||||
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
|
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
|
||||||
"multiple times?");
|
"multiple times?");
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void checkTempFileFromInputStream(File f) throws IOException {
|
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
|
||||||
if (f.length() <= 0) {
|
Map<String, byte[]> result = new HashMap<>();
|
||||||
throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." +
|
try (final ZipInputStream zis = new ZipInputStream(is)) {
|
||||||
" Stream may have been closed before reading, is attempting to be used multiple times, or does not" +
|
while (true) {
|
||||||
" point to a model file?");
|
final ZipEntry zipEntry = zis.getNextEntry();
|
||||||
|
if (zipEntry == null)
|
||||||
|
break;
|
||||||
|
if(zipEntry.isDirectory() || zipEntry.getSize() > Integer.MAX_VALUE)
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
|
||||||
|
final int size = (int) (zipEntry.getSize());
|
||||||
|
final byte[] data;
|
||||||
|
if (size >= 0) { // known size
|
||||||
|
data = IOUtils.readFully(zis, size);
|
||||||
}
|
}
|
||||||
|
else { // unknown size
|
||||||
|
final ByteArrayOutputStream bout = new ByteArrayOutputStream();
|
||||||
|
IOUtils.copy(zis, bout);
|
||||||
|
data = bout.toByteArray();
|
||||||
|
}
|
||||||
|
result.put(zipEntry.getName(), data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static File tempFileFromStream(InputStream is) throws IOException{
|
|
||||||
checkInputStream(is);
|
|
||||||
String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
|
|
||||||
File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
|
|
||||||
try {
|
|
||||||
tmpFile.deleteOnExit();
|
|
||||||
BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
|
|
||||||
IOUtils.copy(is, bufferedOutputStream);
|
|
||||||
bufferedOutputStream.flush();
|
|
||||||
IOUtils.closeQuietly(bufferedOutputStream);
|
|
||||||
checkTempFileFromInputStream(tmpFile);
|
|
||||||
return tmpFile;
|
|
||||||
} catch (IOException e){
|
|
||||||
if(tmpFile != null){
|
|
||||||
tmpFile.delete();
|
|
||||||
}
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue