Fix loading both model and serializer at once from stream + re-add checks
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
e3a8629214
commit
55a3d9bb2c
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.util;
|
package org.deeplearning4j.util;
|
||||||
|
|
||||||
|
import org.apache.commons.io.input.CloseShieldInputStream;
|
||||||
import org.nd4j.shade.guava.io.Files;
|
import org.nd4j.shade.guava.io.Files;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -235,6 +236,11 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
|
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Pair<MultiLayerNetwork, Map<String,byte[]>> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||||
|
throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
Map<String, byte[]> zipFile = loadZipData(is);
|
Map<String, byte[]> zipFile = loadZipData(is);
|
||||||
|
@ -343,7 +349,7 @@ public class ModelSerializer {
|
||||||
if (gotUpdaterState && updaterState != null) {
|
if (gotUpdaterState && updaterState != null) {
|
||||||
network.getUpdater().setStateViewArray(network, updaterState, false);
|
network.getUpdater().setStateViewArray(network, updaterState, false);
|
||||||
}
|
}
|
||||||
return network;
|
return new Pair<>(network, zipFile);
|
||||||
} else
|
} else
|
||||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||||
|
@ -399,9 +405,11 @@ public class ModelSerializer {
|
||||||
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
|
public static Pair<MultiLayerNetwork, Normalizer> restoreMultiLayerNetworkAndNormalizer(
|
||||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
is = new CloseShieldInputStream(is);
|
||||||
|
|
||||||
MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater);
|
Pair<MultiLayerNetwork,Map<String,byte[]>> p = restoreMultiLayerNetworkHelper(is, loadUpdater);
|
||||||
Normalizer norm = restoreNormalizerFromInputStream(is);
|
MultiLayerNetwork net = p.getFirst();
|
||||||
|
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||||
return new Pair<>(net, norm);
|
return new Pair<>(net, norm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -453,6 +461,11 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
|
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
return restoreComputationGraphHelper(is, loadUpdater).getFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Pair<ComputationGraph,Map<String,byte[]>> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater)
|
||||||
|
throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
Map<String, byte[]> files = loadZipData(is);
|
Map<String, byte[]> files = loadZipData(is);
|
||||||
|
@ -564,7 +577,7 @@ public class ModelSerializer {
|
||||||
if (gotUpdaterState && updaterState != null) {
|
if (gotUpdaterState && updaterState != null) {
|
||||||
cg.getUpdater().setStateViewArray(updaterState);
|
cg.getUpdater().setStateViewArray(updaterState);
|
||||||
}
|
}
|
||||||
return cg;
|
return new Pair<>(cg, files);
|
||||||
} else
|
} else
|
||||||
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig
|
||||||
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
|
||||||
|
@ -605,8 +618,10 @@ public class ModelSerializer {
|
||||||
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
@NonNull InputStream is, boolean loadUpdater) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
ComputationGraph net = restoreComputationGraph(is, loadUpdater);
|
|
||||||
Normalizer norm = restoreNormalizerFromInputStream(is);
|
Pair<ComputationGraph,Map<String,byte[]>> p = restoreComputationGraphHelper(is, loadUpdater);
|
||||||
|
ComputationGraph net = p.getFirst();
|
||||||
|
Normalizer norm = restoreNormalizerFromMap(p.getSecond());
|
||||||
return new Pair<>(net, norm);
|
return new Pair<>(net, norm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -900,8 +915,11 @@ public class ModelSerializer {
|
||||||
*/
|
*/
|
||||||
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
|
public static <T extends Normalizer> T restoreNormalizerFromInputStream(InputStream is) throws IOException {
|
||||||
checkInputStream(is);
|
checkInputStream(is);
|
||||||
|
|
||||||
Map<String, byte[]> files = loadZipData(is);
|
Map<String, byte[]> files = loadZipData(is);
|
||||||
|
return restoreNormalizerFromMap(files);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T extends Normalizer> T restoreNormalizerFromMap(Map<String, byte[]> files) throws IOException {
|
||||||
byte[] norm = files.get(NORMALIZER_BIN);
|
byte[] norm = files.get(NORMALIZER_BIN);
|
||||||
|
|
||||||
// checking for file existence
|
// checking for file existence
|
||||||
|
@ -937,8 +955,6 @@ public class ModelSerializer {
|
||||||
|
|
||||||
|
|
||||||
private static void checkInputStream(InputStream inputStream) throws IOException {
|
private static void checkInputStream(InputStream inputStream) throws IOException {
|
||||||
|
|
||||||
/*
|
|
||||||
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
||||||
int available;
|
int available;
|
||||||
try{
|
try{
|
||||||
|
@ -953,7 +969,6 @@ public class ModelSerializer {
|
||||||
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
|
throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" +
|
||||||
"multiple times?");
|
"multiple times?");
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
|
private static Map<String, byte[]> loadZipData(InputStream is) throws IOException {
|
||||||
|
|
Loading…
Reference in New Issue