Fix loading both model and serializer at once from stream + re-add checks

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-31 11:56:56 +11:00
parent e3a8629214
commit 55a3d9bb2c
1 changed files with 26 additions and 11 deletions

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.util;
import org.apache.commons.io.input.CloseShieldInputStream;
import org.nd4j.shade.guava.io.Files;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
@ -235,6 +236,11 @@ public class ModelSerializer {
*/
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
}
private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
checkInputStream(is);
Map<String, byte[]> zipFile = loadZipData(is);
@ -343,7 +349,7 @@ public class ModelSerializer {
if (gotUpdaterState && updaterState != null) {
network.getUpdater().setStateViewArray(network, updaterState, false);
}
return network;
return new Pair<>(network, zipFile);
} else
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
@ -399,9 +405,11 @@ public class ModelSerializer {
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
@NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is);
is = new CloseShieldInputStream(is);
MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater);
Normalizer norm = restoreNormalizerFromInputStream(is);
Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
MultiLayerNetwork net = p.getFirst();
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
return new Pair<>(net, norm);
}
@ -453,6 +461,11 @@ public class ModelSerializer {
*/
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
return restoreComputationGraphHelper(is, loadUpdater).getFirst();
}
private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
checkInputStream(is);
Map<String, byte[]> files = loadZipData(is);
@ -564,7 +577,7 @@ public class ModelSerializer {
if (gotUpdaterState && updaterState != null) {
cg.getUpdater().setStateViewArray(updaterState);
}
return cg;
return new Pair<>(cg, files);
} else
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
@ -604,9 +617,11 @@ public class ModelSerializer {
public static Pair<ComputationGraph, Normalizer> restoreComputationGraphAndNormalizer(
@NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is);
ComputationGraph net = restoreComputationGraph(is, loadUpdater);
Normalizer norm = restoreNormalizerFromInputStream(is);
Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
ComputationGraph net = p.getFirst();
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
return new Pair<>(net, norm);
}
@ -900,8 +915,11 @@ public class ModelSerializer {
*/
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
checkInputStream(is);
Map<String, byte[]> files = loadZipData(is);
return restoreNormalizerFromMap(files);
}
private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
byte[] norm = files.get(NORMALIZER_BIN);
// checking for file existence
@ -937,8 +955,6 @@ public class ModelSerializer {
private static void checkInputStream(InputStream inputStream) throws IOException {
/*
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
int available;
try{
@ -953,7 +969,6 @@ public class ModelSerializer {
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
"multiple times?");
}
*/
}
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {