Fix loading both model and serializer at once from stream + re-add checks
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
e3a8629214
commit
55a3d9bb2c
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.util;
|
||||
|
||||
import org.apache.commons.io.input.CloseShieldInputStream;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -235,6 +236,11 @@ public class ModelSerializer {
|
|||
*/
|
||||
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
|
||||
throws IOException {
|
||||
return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
|
||||
}
|
||||
|
||||
private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||
throws IOException {
|
||||
checkInputStream(is);
|
||||
|
||||
Map<String, byte[]> zipFile = loadZipData(is);
|
||||
|
@ -343,7 +349,7 @@ public class ModelSerializer {
|
|||
if (gotUpdaterState && updaterState != null) {
|
||||
network.getUpdater().setStateViewArray(network, updaterState, false);
|
||||
}
|
||||
return network;
|
||||
return new Pair<>(network, zipFile);
|
||||
} else
|
||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||
|
@ -399,9 +405,11 @@ public class ModelSerializer {
|
|||
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
|
||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||
checkInputStream(is);
|
||||
is = new CloseShieldInputStream(is);
|
||||
|
||||
MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater);
|
||||
Normalizer norm = restoreNormalizerFromInputStream(is);
|
||||
Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
|
||||
MultiLayerNetwork net = p.getFirst();
|
||||
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||
return new Pair<>(net, norm);
|
||||
}
|
||||
|
||||
|
@ -453,6 +461,11 @@ public class ModelSerializer {
|
|||
*/
|
||||
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
|
||||
throws IOException {
|
||||
return restoreComputationGraphHelper(is, loadUpdater).getFirst();
|
||||
}
|
||||
|
||||
private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||
throws IOException {
|
||||
checkInputStream(is);
|
||||
|
||||
Map<String, byte[]> files = loadZipData(is);
|
||||
|
@ -564,7 +577,7 @@ public class ModelSerializer {
|
|||
if (gotUpdaterState && updaterState != null) {
|
||||
cg.getUpdater().setStateViewArray(updaterState);
|
||||
}
|
||||
return cg;
|
||||
return new Pair<>(cg, files);
|
||||
} else
|
||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||
|
@ -604,9 +617,11 @@ public class ModelSerializer {
|
|||
public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(
|
||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||
checkInputStream(is);
|
||||
|
||||
ComputationGraph net = restoreComputationGraph(is, loadUpdater);
|
||||
Normalizer norm = restoreNormalizerFromInputStream(is);
|
||||
|
||||
|
||||
Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
|
||||
ComputationGraph net = p.getFirst();
|
||||
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||
return new Pair<>(net, norm);
|
||||
}
|
||||
|
||||
|
@ -900,8 +915,11 @@ public class ModelSerializer {
|
|||
*/
|
||||
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
|
||||
checkInputStream(is);
|
||||
|
||||
Map<String, byte[]> files = loadZipData(is);
|
||||
return restoreNormalizerFromMap(files);
|
||||
}
|
||||
|
||||
private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
|
||||
byte[] norm = files.get(NORMALIZER_BIN);
|
||||
|
||||
// checking for file existence
|
||||
|
@ -937,8 +955,6 @@ public class ModelSerializer {
|
|||
|
||||
|
||||
private static void checkInputStream(InputStream inputStream) throws IOException {
|
||||
|
||||
/*
|
||||
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
||||
int available;
|
||||
try{
|
||||
|
@ -953,7 +969,6 @@ public class ModelSerializer {
|
|||
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
|
||||
"multiple times?");
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
|
||||
|
|
Loading…
Reference in New Issue