[WIP] Handling binary data in DL4J servlet (#135)
* Binary deser Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary mode for servlet Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * -sRandom image generation copied from datavec * -sRandom image generation copied from datavec * Remove serialization constraints * Fix: Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Removed unused code Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Resources usage Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Async inference Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Cleanup Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * -sTest corrected * Cleanup Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Mutually exclusive serializers/deserializers Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary output supported Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary out test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * - types hardcoded - increased payload size limit Signed-off-by: raver119 <raver119@gmail.com> * change types constant Signed-off-by: raver119 <raver119@gmail.com>master
parent
8e3d569f18
commit
2e99bc2dee
|
@ -18,18 +18,17 @@ package org.deeplearning4j.remote;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.parallelism.ParallelInference;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.remote.serving.SameDiffServlet;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
@ -38,6 +37,8 @@ import java.io.BufferedReader;
|
|||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @author astoyakin
|
||||
|
@ -51,7 +52,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
protected boolean parallelEnabled = true;
|
||||
|
||||
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.parallelInference = parallelInference;
|
||||
this.model = null;
|
||||
|
@ -59,19 +60,67 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
}
|
||||
|
||||
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.model = model;
|
||||
this.parallelInference = null;
|
||||
this.parallelEnabled = false;
|
||||
}
|
||||
|
||||
/**
 * Creates a servlet backed by a {@link ParallelInference} instance that is configured with a
 * binary serializer/deserializer pair only.
 *
 * @param parallelInference ParallelInference instance used to compute network outputs
 * @param inferenceAdapter  adapter applied to deserialized input and to network output
 * @param serializer        binary serializer for the output type
 * @param deserializer      binary deserializer for the input type
 */
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
                   BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer) {
    super(inferenceAdapter, serializer, deserializer);

    // parallel mode: no directly held model
    this.parallelEnabled = true;
    this.model = null;
    this.parallelInference = parallelInference;
}
|
||||
|
||||
/**
 * Creates a servlet that holds the model directly (parallel inference disabled) and is
 * configured with both JSON and binary serializer/deserializer pairs.
 *
 * @param model              model used to compute outputs
 * @param inferenceAdapter   adapter applied to deserialized input and to network output
 * @param jsonSerializer     JSON serializer for the output type
 * @param jsonDeserializer   JSON deserializer for the input type
 * @param binarySerializer   binary serializer for the output type
 * @param binaryDeserializer binary deserializer for the input type
 */
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
                   JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
                   BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
    super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);

    // single-model mode: no ParallelInference instance
    this.parallelEnabled = false;
    this.parallelInference = null;
    this.model = model;
}
|
||||
|
||||
/**
 * Creates a servlet backed by a {@link ParallelInference} instance and configured with both
 * JSON and binary serializer/deserializer pairs.
 *
 * @param parallelInference  ParallelInference instance used to compute network outputs
 * @param inferenceAdapter   adapter applied to deserialized input and to network output
 * @param jsonSerializer     JSON serializer for the output type
 * @param jsonDeserializer   JSON deserializer for the input type
 * @param binarySerializer   binary serializer for the output type
 * @param binaryDeserializer binary deserializer for the input type
 */
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
                   JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
                   BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
    super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);

    // parallel mode: no directly held model
    this.parallelEnabled = true;
    this.model = null;
    this.parallelInference = parallelInference;
}
|
||||
|
||||
private O process(MultiDataSet mds) {
|
||||
O result = null;
|
||||
if (parallelEnabled) {
|
||||
// process result
|
||||
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
} else {
|
||||
synchronized (this) {
|
||||
if (model instanceof ComputationGraph)
|
||||
result = inferenceAdapter.apply(((ComputationGraph) model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
else if (model instanceof MultiLayerNetwork) {
|
||||
Preconditions.checkArgument(mds.getFeatures().length > 0 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 0),
|
||||
"Input data for MultilayerNetwork is invalid!");
|
||||
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
||||
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||
String processorReturned = "";
|
||||
MultiDataSet mds = null;
|
||||
String path = request.getPathInfo();
|
||||
if (path.equals(SERVING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
if (contentType.equals(typeJson)) {
|
||||
if (validateRequest(request, response)) {
|
||||
val stream = request.getInputStream();
|
||||
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
||||
|
@ -83,31 +132,35 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
}
|
||||
val requestString = buffer.toString();
|
||||
|
||||
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||
|
||||
O result = null;
|
||||
if (parallelEnabled) {
|
||||
// process result
|
||||
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||
}
|
||||
}
|
||||
else if (contentType.equals(typeBinary)) {
|
||||
val stream = request.getInputStream();
|
||||
int available = request.getContentLength();
|
||||
if (available <= 0) {
|
||||
response.sendError(411, "Content length is unavailable");
|
||||
}
|
||||
else {
|
||||
synchronized(this) {
|
||||
if (model instanceof ComputationGraph)
|
||||
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
else if (model instanceof MultiLayerNetwork) {
|
||||
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
|
||||
"Input data for MultilayerNetwork is invalid!");
|
||||
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
||||
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
||||
byte[] data = new byte[available];
|
||||
stream.read(data, 0, available);
|
||||
|
||||
mds = inferenceAdapter.apply(binaryDeserializer.deserialize(data));
|
||||
}
|
||||
}
|
||||
if (mds == null)
|
||||
log.error("InferenceAdapter failed");
|
||||
else {
|
||||
val result = process(mds);
|
||||
if (binarySerializer != null) {
|
||||
byte[] serialized = binarySerializer.serialize(result);
|
||||
response.setContentType(typeBinary);
|
||||
response.setContentLength(serialized.length);
|
||||
val out = response.getOutputStream();
|
||||
out.write(serialized);
|
||||
}
|
||||
else {
|
||||
processorReturned = serializer.serialize(result);
|
||||
}
|
||||
} else {
|
||||
// we return error otherwise
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
try {
|
||||
val out = response.getWriter();
|
||||
out.write(processorReturned);
|
||||
|
@ -115,6 +168,12 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// we return error otherwise
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates servlet to serve models
|
||||
|
@ -133,6 +192,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
private BinarySerializer<O> binarySerializer;
|
||||
private BinaryDeserializer<I> binaryDeserializer;
|
||||
private int port;
|
||||
private boolean parallelEnabled = true;
|
||||
|
||||
|
@ -155,7 +216,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) {
|
||||
public Builder<I,O> serializer(JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
@ -166,11 +227,33 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
public Builder<I,O> deserializer(JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is required to specify serializer
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> binarySerializer(BinarySerializer<O> serializer) {
|
||||
this.binarySerializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows to specify deserializer
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> binaryDeserializer(BinaryDeserializer<I> deserializer) {
|
||||
this.binaryDeserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows to specify port
|
||||
*
|
||||
|
@ -194,8 +277,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
|||
}
|
||||
|
||||
public DL4jServlet<I,O> build() {
|
||||
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
|
||||
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
|
||||
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer) :
|
||||
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.remote.SameDiffJsonModelServer;
|
||||
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
|
@ -70,28 +72,40 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
|
||||
protected boolean enabledParallel = true;
|
||||
|
||||
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
||||
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
||||
super(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||
|
||||
this.cgModel = cgModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||
|
||||
this.mlnModel = mlnModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port) {
|
||||
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||
|
||||
this.parallelInference = pi;
|
||||
}
|
||||
|
@ -139,6 +153,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
.parallelEnabled(true)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.binarySerializer(binarySerializer)
|
||||
.binaryDeserializer(binaryDeserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
|
@ -147,6 +163,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
.parallelEnabled(false)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.binarySerializer(binarySerializer)
|
||||
.binaryDeserializer(binaryDeserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
|
@ -175,6 +193,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
private BinarySerializer<O> binarySerializer;
|
||||
private BinaryDeserializer<I> binaryDeserializer;
|
||||
|
||||
private InputAdapter<I> inputAdapter;
|
||||
private OutputAdapter<O> outputAdapter;
|
||||
|
@ -238,7 +258,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify serializer
|
||||
* This method allows you to specify JSON serializer.
|
||||
* Incompatible with {@link #outputBinarySerializer(BinarySerializer)}
|
||||
* Only one serializer - deserializer pair can be used by client and server.
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
|
@ -249,7 +271,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify deserializer
|
||||
* This method allows you to specify JSON deserializer.
|
||||
* Incompatible with {@link #inputBinaryDeserializer(BinaryDeserializer)}
|
||||
* Only one serializer - deserializer pair can be used by client and server.
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
|
@ -259,6 +283,32 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify binary serializer.
|
||||
* Incompatible with {@link #outputSerializer(JsonSerializer)}
|
||||
* Only one serializer - deserializer pair can be used by client and server.
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputBinarySerializer(@NonNull BinarySerializer<O> serializer) {
|
||||
this.binarySerializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify binary deserializer
|
||||
* Incompatible with {@link #inputDeserializer(JsonDeserializer)}
|
||||
* Only one serializer - deserializer pair can be used by client and server.
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputBinaryDeserializer(@NonNull BinaryDeserializer<I> deserializer) {
|
||||
this.binaryDeserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
|
||||
*
|
||||
|
@ -375,17 +425,24 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
|||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||
}
|
||||
|
||||
JsonModelServer server = null;
|
||||
if (sdModel != null) {
|
||||
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
|
||||
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
} else if (cgModel != null)
|
||||
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (mlnModel != null)
|
||||
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (pi != null)
|
||||
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port);
|
||||
server = new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
else if (cgModel != null) {
|
||||
server = new JsonModelServer<I, O>(cgModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
|
||||
}
|
||||
else if (mlnModel != null) {
|
||||
server = new JsonModelServer<I, O>(mlnModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
|
||||
}
|
||||
else if (pi != null) {
|
||||
server = new JsonModelServer<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||
}
|
||||
else
|
||||
throw new IllegalStateException("No models were defined for JsonModelServer");
|
||||
|
||||
server.enabledParallel = parallelMode;
|
||||
return server;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.val;
|
||||
import org.datavec.image.loader.Java2DNativeImageLoader;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.remote.helpers.ImageConversionUtils;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.junit.After;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.remote.clients.JsonRemoteInference;
|
||||
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.nio.Buffer;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class BinaryModelServerTest {
|
||||
private final int PORT = 18080;
|
||||
|
||||
@After
|
||||
public void pause() throws Exception {
|
||||
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
||||
TimeUnit.SECONDS.sleep(2);
|
||||
}
|
||||
|
||||
// Internal test for locally defined serializers
|
||||
@Test
|
||||
public void testBufferedImageSerde() {
|
||||
BinarySerializer<BufferedImage> serde = new BinaryModelServerTest.BufferedImageSerde();
|
||||
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||
byte[] serialized = serde.serialize(image);
|
||||
|
||||
BufferedImage deserialized = ((BufferedImageSerde) serde).deserialize(serialized);
|
||||
int originalSize = image.getData().getDataBuffer().getSize();
|
||||
assertEquals(originalSize, deserialized.getData().getDataBuffer().getSize());
|
||||
for (int i = 0; i < originalSize; ++i) {
|
||||
assertEquals(deserialized.getData().getDataBuffer().getElem(i),
|
||||
image.getData().getDataBuffer().getElem(i));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testImageToINDArray() {
|
||||
INDArray data = ImageConversionUtils.makeRandomImageAsINDArray(28,28,1);
|
||||
assertNotNull(data);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMlnMnist_ImageInput() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||
.outputSerializer(new IntegerSerde())
|
||||
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(BufferedImage input) {
|
||||
INDArray data = null;
|
||||
try {
|
||||
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||
data = data.reshape(1, 784);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new MultiDataSet(data, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.port(PORT)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.inputBinarySerializer(new BufferedImageSerde())
|
||||
.outputDeserializer(new IntegerSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||
Integer result = client.predict(image);
|
||||
assertNotNull(result);
|
||||
|
||||
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||||
image = ImageIO.read(new FileInputStream(file));
|
||||
result = client.predict(image);
|
||||
assertEquals(new Integer(0), result);
|
||||
|
||||
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||||
image = ImageIO.read(new FileInputStream(file));
|
||||
result = client.predict(image);
|
||||
assertEquals(new Integer(1), result);
|
||||
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMlnMnist_ImageInput_Async() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||
.outputSerializer(new IntegerSerde())
|
||||
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(BufferedImage input) {
|
||||
INDArray data = null;
|
||||
try {
|
||||
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||
data = data.reshape(1, 784);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new MultiDataSet(data, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.port(PORT)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.inputBinarySerializer(new BufferedImageSerde())
|
||||
.outputDeserializer(new IntegerSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
BufferedImage[] images = new BufferedImage[3];
|
||||
images[0] = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||
|
||||
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||||
images[1] = ImageIO.read(new FileInputStream(file));
|
||||
|
||||
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||||
images[2] = ImageIO.read(new FileInputStream(file));
|
||||
|
||||
Future<Integer>[] results = new Future[3];
|
||||
for (int i = 0; i < images.length; ++i) {
|
||||
results[i] = client.predictAsync(images[i]);
|
||||
assertNotNull(results[i]);
|
||||
}
|
||||
|
||||
assertNotNull(results[0].get());
|
||||
assertEquals(new Integer(0), results[1].get());
|
||||
assertEquals(new Integer(1), results[2].get());
|
||||
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBinaryIn_BinaryOut() throws Exception {
|
||||
|
||||
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||
|
||||
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
||||
.outputBinarySerializer(new BufferedImageSerde())
|
||||
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<BufferedImage, BufferedImage>() {
|
||||
@Override
|
||||
public MultiDataSet apply(BufferedImage input) {
|
||||
INDArray data = null;
|
||||
try {
|
||||
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new MultiDataSet(data, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BufferedImage apply(INDArray... nnOutput) {
|
||||
return ImageConversionUtils.makeRandomBufferedImage(28,28,3);
|
||||
}
|
||||
})
|
||||
.port(PORT)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<BufferedImage, BufferedImage>builder()
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.inputBinarySerializer(new BufferedImageSerde())
|
||||
.outputBinaryDeserializer(new BufferedImageSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||
BufferedImage result = client.predict(image);
|
||||
assertNotNull(result);
|
||||
assertEquals(28, result.getHeight());
|
||||
assertEquals(28, result.getWidth());
|
||||
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedImageSerde implements BinarySerializer<BufferedImage>, BinaryDeserializer<BufferedImage> {
|
||||
|
||||
@Override
|
||||
public BufferedImage deserialize(byte[] buffer) {
|
||||
try {
|
||||
BufferedImage img = ImageIO.read(new ByteArrayInputStream(buffer));
|
||||
return img;
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] serialize(BufferedImage image) {
|
||||
try{
|
||||
val baos = new ByteArrayOutputStream();
|
||||
ImageIO.write(image, "bmp", baos);
|
||||
byte[] bytes = baos.toByteArray();
|
||||
return bytes;
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.indexer.UByteIndexer;
|
||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.bytedeco.opencv.opencv_core.Mat;
|
||||
import org.datavec.image.loader.Java2DNativeImageLoader;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.bytedeco.opencv.global.opencv_core.CV_8UC;
|
||||
|
||||
public class ImageConversionUtils {
|
||||
|
||||
public static Mat makeRandomImage(int height, int width, int channels) {
|
||||
if (height <= 0) {
|
||||
|
||||
height = new Random().nextInt() % 100 + 100;
|
||||
}
|
||||
if (width <= 0) {
|
||||
width = new Random().nextInt() % 100 + 100;
|
||||
}
|
||||
|
||||
Mat img = new Mat(height, width, CV_8UC(channels));
|
||||
UByteIndexer idx = img.createIndexer();
|
||||
for (int i = 0; i < height; i++) {
|
||||
for (int j = 0; j < width; j++) {
|
||||
for (int k = 0; k < channels; k++) {
|
||||
idx.put(i, j, k, new Random().nextInt());
|
||||
}
|
||||
}
|
||||
}
|
||||
return img;
|
||||
}
|
||||
|
||||
public static BufferedImage makeRandomBufferedImage(int height, int width, int channels) {
|
||||
Mat img = makeRandomImage(height, width, channels);
|
||||
|
||||
OpenCVFrameConverter.ToMat c = new OpenCVFrameConverter.ToMat();
|
||||
Java2DFrameConverter c2 = new Java2DFrameConverter();
|
||||
|
||||
return c2.convert(c.convert(img));
|
||||
}
|
||||
|
||||
public static INDArray convert(BufferedImage image) {
|
||||
INDArray retVal = null;
|
||||
try {
|
||||
retVal = new Java2DNativeImageLoader(image.getHeight(), image.getWidth(), image.getRaster().getNumBands()).
|
||||
asRowVector(image);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return retVal;
|
||||
}
|
||||
|
||||
public static INDArray convert(Mat image) {
|
||||
INDArray retVal = null;
|
||||
try {
|
||||
new NativeImageLoader().asRowVector(image);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return retVal;
|
||||
}
|
||||
|
||||
public static BufferedImage convert(INDArray input) {
|
||||
return new Java2DNativeImageLoader(input.rows(),input.columns()).asBufferedImage(input);
|
||||
}
|
||||
|
||||
public static INDArray makeRandomImageAsINDArray(int height, int width, int channels) {
|
||||
val image = makeRandomBufferedImage(height, width, channels);
|
||||
INDArray retVal = convert(image);
|
||||
return retVal;
|
||||
}
|
||||
}
|
|
@ -69,7 +69,7 @@
|
|||
</developer>
|
||||
<developer>
|
||||
<id>raver119</id>
|
||||
<name>raver119</name>
|
||||
<name>Vyacheslav Kokorin</name>
|
||||
</developer>
|
||||
<developer>
|
||||
<id>saudet</id>
|
||||
|
|
|
@ -24,10 +24,12 @@ import lombok.NonNull;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.json.JSONObject;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.remote.clients.serde.*;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
@ -49,21 +51,57 @@ import java.util.concurrent.TimeoutException;
|
|||
@Slf4j
|
||||
public class JsonRemoteInference<I, O> {
|
||||
private String endpointAddress;
|
||||
// JSON serializer/deserializer and binary serializer/deserializer are mutually exclusive.
|
||||
private JsonSerializer<I> serializer;
|
||||
private JsonDeserializer<O> deserializer;
|
||||
private BinarySerializer<I> binarySerializer;
|
||||
private BinaryDeserializer<O> binaryDeserializer;
|
||||
|
||||
private final static String APPLICATION_JSON = "application/json";
|
||||
private final static String APPLICATION_OCTET_STREAM = "application/octet-stream";
|
||||
|
||||
@Builder
|
||||
public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer<I> inputSerializer, @NonNull JsonDeserializer<O> outputDeserializer) {
|
||||
public JsonRemoteInference(@NonNull String endpointAddress,
|
||||
JsonSerializer<I> inputSerializer, JsonDeserializer<O> outputDeserializer,
|
||||
BinarySerializer<I> inputBinarySerializer, BinaryDeserializer<O> outputBinaryDeserializer) {
|
||||
|
||||
this.endpointAddress = endpointAddress;
|
||||
this.serializer = inputSerializer;
|
||||
this.deserializer = outputDeserializer;
|
||||
this.binarySerializer = inputBinarySerializer;
|
||||
this.binaryDeserializer = outputBinaryDeserializer;
|
||||
|
||||
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
|
||||
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
|
||||
}
|
||||
|
||||
|
||||
private O processResponse(HttpResponse<String> response) throws IOException {
|
||||
if (response.getStatus() != 200)
|
||||
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
||||
|
||||
O result = deserializer.deserialize(response.getBody());
|
||||
|
||||
if (result == null) {
|
||||
throw new IOException("Deserialization failed!");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private O processResponseBinary(HttpResponse<InputStream> response) throws IOException {
|
||||
if (response.getStatus() != 200)
|
||||
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
||||
|
||||
List<String> values = response.getHeaders().get("Content-Length");
|
||||
if (values == null || values.size() < 1) {
|
||||
throw new IOException("Content-Length is required for binary data");
|
||||
}
|
||||
|
||||
String strLength = values.get(0);
|
||||
byte[] bytes = new byte[Integer.parseInt(strLength)];
|
||||
response.getBody().read(bytes);
|
||||
O result = binaryDeserializer.deserialize(bytes);
|
||||
|
||||
if (result == null) {
|
||||
throw new IOException("Deserialization failed!");
|
||||
}
|
||||
|
@ -79,12 +117,30 @@ public class JsonRemoteInference<I, O> {
|
|||
*/
|
||||
public O predict(I input) throws IOException {
|
||||
try {
|
||||
val stringResult = Unirest.post(endpointAddress)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
if (binarySerializer != null && binaryDeserializer != null) {
|
||||
HttpResponse<InputStream> response =
|
||||
Unirest.post(endpointAddress)
|
||||
.header("Content-Type", APPLICATION_OCTET_STREAM)
|
||||
.header("Accept", APPLICATION_OCTET_STREAM)
|
||||
.body(binarySerializer.serialize(input)).asBinary();
|
||||
return processResponseBinary(response);
|
||||
}
|
||||
else if (binarySerializer != null && binaryDeserializer == null) {
|
||||
HttpResponse<String> response =
|
||||
Unirest.post(endpointAddress)
|
||||
.header("Content-Type", APPLICATION_OCTET_STREAM)
|
||||
.header("Accept", APPLICATION_OCTET_STREAM)
|
||||
.body(binarySerializer.serialize(input)).asString();
|
||||
return processResponse(response);
|
||||
}
|
||||
else {
|
||||
HttpResponse<String> response = Unirest.post(endpointAddress)
|
||||
.header("Content-Type", APPLICATION_JSON)
|
||||
.header("Accept", APPLICATION_JSON)
|
||||
.body(new JSONObject(serializer.serialize(input))).asString();
|
||||
return processResponse(response);
|
||||
}
|
||||
|
||||
return processResponse(stringResult);
|
||||
} catch (UnirestException e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
|
@ -96,11 +152,19 @@ public class JsonRemoteInference<I, O> {
|
|||
* @return
|
||||
*/
|
||||
public Future<O> predictAsync(I input) {
|
||||
val stringResult = Unirest.post(endpointAddress)
|
||||
|
||||
Future<HttpResponse<String>> response = binarySerializer != null ?
|
||||
Unirest.post(endpointAddress)
|
||||
.header("Content-Type", "application/octet-stream")
|
||||
.header("Accept", "application/octet-stream")
|
||||
.body(binarySerializer.serialize(input)).asStringAsync() :
|
||||
|
||||
Unirest.post(endpointAddress)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.body(new JSONObject(serializer.serialize(input))).asStringAsync();
|
||||
return new InferenceFuture(stringResult);
|
||||
|
||||
return new InferenceFuture(response);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -151,3 +215,4 @@ public class JsonRemoteInference<I, O> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.remote.clients.serde;
|
||||
|
||||
/**
|
||||
* This interface describes basic binary deserializer interface used for remote inference
|
||||
* @param <T> type of the deserializable class
|
||||
*
|
||||
* @author Alexander Stoyakin
|
||||
*/
|
||||
public interface BinaryDeserializer<T> {
|
||||
|
||||
/**
|
||||
* This method deserializes binary data to arbitrary object.
|
||||
* @param buffer byte buffer holding the binary data to deserialize
|
||||
* @return deserialized object
|
||||
*/
|
||||
T deserialize(byte[] buffer);
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde;
|
||||
|
||||
/**
|
||||
* This interface describes basic binary serializer interface used for remote inference
|
||||
* @param <T> type of the serializable class
|
||||
*
|
||||
* @author Alexander Stoyakin
|
||||
*/
|
||||
public interface BinarySerializer<T> {
|
||||
|
||||
/**
|
||||
* This method serializes given object into byte buffer
|
||||
*
|
||||
* @param o object to be serialized
|
||||
* @return byte buffer containing the serialized object
|
||||
*/
|
||||
byte[] serialize(T o);
|
||||
}
|
|
@ -26,6 +26,8 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
|
@ -51,6 +53,8 @@ public class SameDiffJsonModelServer<I, O> {
|
|||
protected SameDiff sdModel;
|
||||
protected final JsonSerializer<O> serializer;
|
||||
protected final JsonDeserializer<I> deserializer;
|
||||
protected final BinarySerializer<O> binarySerializer;
|
||||
protected final BinaryDeserializer<I> binaryDeserializer;
|
||||
protected final InferenceAdapter<I, O> inferenceAdapter;
|
||||
protected final int port;
|
||||
|
||||
|
@ -64,9 +68,18 @@ public class SameDiffJsonModelServer<I, O> {
|
|||
protected String[] orderedInputNodes;
|
||||
protected String[] orderedOutputNodes;
|
||||
|
||||
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port) {
|
||||
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port) {
|
||||
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
|
||||
Preconditions.checkArgument(serializer == null && binarySerializer == null ||
|
||||
serializer != null && binarySerializer == null ||
|
||||
serializer == null && binarySerializer != null,
|
||||
"JSON and binary serializers/deserializers are mutually exclusive and mandatory.");
|
||||
|
||||
this.binarySerializer = binarySerializer;
|
||||
this.binaryDeserializer = binaryDeserializer;
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
this.serializer = serializer;
|
||||
this.deserializer = deserializer;
|
||||
|
@ -74,8 +87,11 @@ public class SameDiffJsonModelServer<I, O> {
|
|||
}
|
||||
|
||||
//@Builder
|
||||
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
|
||||
this(inferenceAdapter, serializer, deserializer, port);
|
||||
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||
int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
|
||||
this(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||
this.sdModel = sdModel;
|
||||
this.orderedInputNodes = orderedInputNodes;
|
||||
this.orderedOutputNodes = orderedOutputNodes;
|
||||
|
@ -282,7 +298,7 @@ public class SameDiffJsonModelServer<I, O> {
|
|||
} else
|
||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||
}
|
||||
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, null, null, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
|
@ -35,6 +37,7 @@ import java.io.InputStreamReader;
|
|||
import java.util.LinkedHashMap;
|
||||
|
||||
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
||||
import static javax.ws.rs.core.MediaType.APPLICATION_OCTET_STREAM;
|
||||
|
||||
/**
|
||||
* This servlet provides SameDiff model serving capabilities
|
||||
|
@ -50,9 +53,14 @@ import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
|||
@Builder
|
||||
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||
|
||||
protected static final String typeJson = APPLICATION_JSON;
|
||||
protected static final String typeBinary = APPLICATION_OCTET_STREAM;
|
||||
|
||||
protected SameDiff sdModel;
|
||||
protected JsonSerializer<O> serializer;
|
||||
protected JsonDeserializer<I> deserializer;
|
||||
protected BinarySerializer<O> binarySerializer;
|
||||
protected BinaryDeserializer<I> binaryDeserializer;
|
||||
protected InferenceAdapter<I, O> inferenceAdapter;
|
||||
|
||||
protected String[] orderedInputNodes;
|
||||
|
@ -60,14 +68,36 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
|||
|
||||
protected final static String SERVING_ENDPOINT = "/v1/serving";
|
||||
protected final static String LISTING_ENDPOINT = "/v1";
|
||||
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable
|
||||
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024 * 1024; // TODO: should be customizable
|
||||
|
||||
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer){
|
||||
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer){
    // JSON-only configuration: the binary serde fields are not touched by this constructor.
    this.inferenceAdapter = inferenceAdapter;
    this.deserializer = deserializer;
    this.serializer = serializer;
}
|
||||
|
||||
/**
 * Creates a servlet configured with binary serde only; the JSON serde fields are not
 * assigned by this constructor.
 *
 * @param inferenceAdapter adapter used for inference, must not be null
 * @param serializer       binary serializer for the output type
 * @param deserializer     binary deserializer for the input type
 */
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
                          BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer){
    this.inferenceAdapter = inferenceAdapter;
    this.binaryDeserializer = deserializer;
    this.binarySerializer = serializer;
}
|
||||
|
||||
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
|
||||
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer){
|
||||
|
||||
this.serializer = jsonSerializer;
|
||||
this.deserializer = jsonDeserializer;
|
||||
this.binarySerializer = binarySerializer;
|
||||
this.binaryDeserializer = binaryDeserializer;
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
|
||||
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
|
||||
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void init(ServletConfig servletConfig) throws ServletException {
|
||||
//
|
||||
|
@ -108,7 +138,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
|||
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
|
||||
throws IOException{
|
||||
val contentType = request.getContentType();
|
||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
||||
if (!StringUtils.equals(contentType, typeJson)) {
|
||||
sendBadContentType(contentType, response);
|
||||
int contentLength = request.getContentLength();
|
||||
if (contentLength > PAYLOAD_SIZE_LIMIT) {
|
||||
|
@ -125,7 +155,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
|||
String path = request.getPathInfo();
|
||||
if (path.equals(LISTING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
||||
if (!StringUtils.equals(contentType, typeJson)) {
|
||||
sendBadContentType(contentType, response);
|
||||
}
|
||||
processorReturned = processor.listEndpoints();
|
||||
|
@ -147,7 +177,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
|||
String path = request.getPathInfo();
|
||||
if (path.equals(SERVING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
/*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON),
|
||||
/*Preconditions.checkArgument(StringUtils.equals(contentType, typeJson),
|
||||
"Content type is " + contentType);*/
|
||||
if (validateRequest(request,response)) {
|
||||
val stream = request.getInputStream();
|
||||
|
|
Loading…
Reference in New Issue