Fix loading both model and serializer at once from stream + re-add checks

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-31 11:56:56 +11:00
parent e3a8629214
commit 55a3d9bb2c
1 changed files with 26 additions and 11 deletions

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.util; package org.deeplearning4j.util;
import org.apache.commons.io.input.CloseShieldInputStream;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -235,6 +236,11 @@ public class ModelSerializer {
*/ */
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
throws IOException { throws IOException {
return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
}
private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
checkInputStream(is); checkInputStream(is);
Map<String, byte[]> zipFile = loadZipData(is); Map<String, byte[]> zipFile = loadZipData(is);
@ -343,7 +349,7 @@ public class ModelSerializer {
if (gotUpdaterState && updaterState != null) { if (gotUpdaterState && updaterState != null) {
network.getUpdater().setStateViewArray(network, updaterState, false); network.getUpdater().setStateViewArray(network, updaterState, false);
} }
return network; return new Pair<>(network, zipFile);
} else } else
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
@ -399,9 +405,11 @@ public class ModelSerializer {
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer( public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
@NonNull InputStream is, boolean loadUpdater) throws IOException { @NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is); checkInputStream(is);
is = new CloseShieldInputStream(is);
MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater); Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
Normalizer norm = restoreNormalizerFromInputStream(is); MultiLayerNetwork net = p.getFirst();
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
return new Pair<>(net, norm); return new Pair<>(net, norm);
} }
@ -453,6 +461,11 @@ public class ModelSerializer {
*/ */
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater) public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
throws IOException { throws IOException {
return restoreComputationGraphHelper(is, loadUpdater).getFirst();
}
private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
throws IOException {
checkInputStream(is); checkInputStream(is);
Map<String, byte[]> files = loadZipData(is); Map<String, byte[]> files = loadZipData(is);
@ -564,7 +577,7 @@ public class ModelSerializer {
if (gotUpdaterState && updaterState != null) { if (gotUpdaterState && updaterState != null) {
cg.getUpdater().setStateViewArray(updaterState); cg.getUpdater().setStateViewArray(updaterState);
} }
return cg; return new Pair<>(cg, files);
} else } else
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
@ -605,8 +618,10 @@ public class ModelSerializer {
@NonNull InputStream is, boolean loadUpdater) throws IOException { @NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is); checkInputStream(is);
ComputationGraph net = restoreComputationGraph(is, loadUpdater);
Normalizer norm = restoreNormalizerFromInputStream(is); Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
ComputationGraph net = p.getFirst();
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
return new Pair<>(net, norm); return new Pair<>(net, norm);
} }
@ -900,8 +915,11 @@ public class ModelSerializer {
*/ */
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException { public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
checkInputStream(is); checkInputStream(is);
Map<String, byte[]> files = loadZipData(is); Map<String, byte[]> files = loadZipData(is);
return restoreNormalizerFromMap(files);
}
private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
byte[] norm = files.get(NORMALIZER_BIN); byte[] norm = files.get(NORMALIZER_BIN);
// checking for file existence // checking for file existence
@ -937,8 +955,6 @@ public class ModelSerializer {
private static void checkInputStream(InputStream inputStream) throws IOException { private static void checkInputStream(InputStream inputStream) throws IOException {
/*
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887 //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
int available; int available;
try{ try{
@ -953,7 +969,6 @@ public class ModelSerializer {
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" + throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
"multiple times?"); "multiple times?");
} }
*/
} }
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException { private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {