[WIP] Handling binary data in DL4J servlet (#135)

* Binary deser

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Binary mode for servlet

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Added test

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* -sRandom image generation copied from datavec

* -sRandom image generation copied from datavec

* Remove serialization constraints

* Fix:

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Removed unused code

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Resources usage

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Async inference

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Cleanup

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* -sTest corrected

* Cleanup

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Mutually eclusive serializers/deserializers

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Binary output supported

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Binary out test

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* - types hardcoded
- increased payload size limit

Signed-off-by: raver119 <raver119@gmail.com>

* change types constant

Signed-off-by: raver119 <raver119@gmail.com>
master
Alexander Stoyakin 2019-08-23 17:00:55 +03:00 committed by raver119
parent 8e3d569f18
commit 2e99bc2dee
10 changed files with 767 additions and 92 deletions

View File

@ -18,18 +18,17 @@ package org.deeplearning4j.remote;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference; import org.deeplearning4j.parallelism.ParallelInference;
import org.nd4j.adapters.InferenceAdapter; import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.remote.serving.SameDiffServlet; import org.nd4j.remote.serving.SameDiffServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -38,6 +37,8 @@ import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
/** /**
* *
* @author astoyakin * @author astoyakin
@ -51,7 +52,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
protected boolean parallelEnabled = true; protected boolean parallelEnabled = true;
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter, public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) { JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
super(inferenceAdapter, serializer, deserializer); super(inferenceAdapter, serializer, deserializer);
this.parallelInference = parallelInference; this.parallelInference = parallelInference;
this.model = null; this.model = null;
@ -59,20 +60,68 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
} }
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter, public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) { JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
super(inferenceAdapter, serializer, deserializer); super(inferenceAdapter, serializer, deserializer);
this.model = model; this.model = model;
this.parallelInference = null; this.parallelInference = null;
this.parallelEnabled = false; this.parallelEnabled = false;
} }
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer) {
super(inferenceAdapter, serializer, deserializer);
this.parallelInference = parallelInference;
this.model = null;
this.parallelEnabled = true;
}
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
this.model = model;
this.parallelInference = null;
this.parallelEnabled = false;
}
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
this.parallelInference = parallelInference;
this.model = null;
this.parallelEnabled = true;
}
private O process(MultiDataSet mds) {
O result = null;
if (parallelEnabled) {
// process result
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
} else {
synchronized (this) {
if (model instanceof ComputationGraph)
result = inferenceAdapter.apply(((ComputationGraph) model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
else if (model instanceof MultiLayerNetwork) {
Preconditions.checkArgument(mds.getFeatures().length > 0 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 0),
"Input data for MultilayerNetwork is invalid!");
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
}
}
}
return result;
}
@Override @Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
String processorReturned = ""; String processorReturned = "";
MultiDataSet mds = null;
String path = request.getPathInfo(); String path = request.getPathInfo();
if (path.equals(SERVING_ENDPOINT)) { if (path.equals(SERVING_ENDPOINT)) {
val contentType = request.getContentType(); val contentType = request.getContentType();
if (validateRequest(request,response)) { if (contentType.equals(typeJson)) {
if (validateRequest(request, response)) {
val stream = request.getInputStream(); val stream = request.getInputStream();
val bufferedReader = new BufferedReader(new InputStreamReader(stream)); val bufferedReader = new BufferedReader(new InputStreamReader(stream));
char[] charBuffer = new char[128]; char[] charBuffer = new char[128];
@ -83,31 +132,35 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
} }
val requestString = buffer.toString(); val requestString = buffer.toString();
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString)); mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
}
O result = null; }
if (parallelEnabled) { else if (contentType.equals(typeBinary)) {
// process result val stream = request.getInputStream();
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays())); int available = request.getContentLength();
if (available <= 0) {
response.sendError(411, "Content length is unavailable");
} }
else { else {
synchronized(this) { byte[] data = new byte[available];
if (model instanceof ComputationGraph) stream.read(data, 0, available);
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
else if (model instanceof MultiLayerNetwork) { mds = inferenceAdapter.apply(binaryDeserializer.deserialize(data));
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
"Input data for MultilayerNetwork is invalid!");
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
} }
} }
if (mds == null)
log.error("InferenceAdapter failed");
else {
val result = process(mds);
if (binarySerializer != null) {
byte[] serialized = binarySerializer.serialize(result);
response.setContentType(typeBinary);
response.setContentLength(serialized.length);
val out = response.getOutputStream();
out.write(serialized);
} }
else {
processorReturned = serializer.serialize(result); processorReturned = serializer.serialize(result);
}
} else {
// we return error otherwise
sendError(request.getRequestURI(), response);
}
try { try {
val out = response.getWriter(); val out = response.getWriter();
out.write(processorReturned); out.write(processorReturned);
@ -115,6 +168,12 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
log.error(e.getMessage()); log.error(e.getMessage());
} }
} }
}
} else {
// we return error otherwise
sendError(request.getRequestURI(), response);
}
}
/** /**
* Creates servlet to serve models * Creates servlet to serve models
@ -133,6 +192,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
private InferenceAdapter<I, O> inferenceAdapter; private InferenceAdapter<I, O> inferenceAdapter;
private JsonSerializer<O> serializer; private JsonSerializer<O> serializer;
private JsonDeserializer<I> deserializer; private JsonDeserializer<I> deserializer;
private BinarySerializer<O> binarySerializer;
private BinaryDeserializer<I> binaryDeserializer;
private int port; private int port;
private boolean parallelEnabled = true; private boolean parallelEnabled = true;
@ -155,7 +216,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
* @param serializer * @param serializer
* @return * @return
*/ */
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) { public Builder<I,O> serializer(JsonSerializer<O> serializer) {
this.serializer = serializer; this.serializer = serializer;
return this; return this;
} }
@ -166,11 +227,33 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
* @param deserializer * @param deserializer
* @return * @return
*/ */
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) { public Builder<I,O> deserializer(JsonDeserializer<I> deserializer) {
this.deserializer = deserializer; this.deserializer = deserializer;
return this; return this;
} }
/**
* This method is required to specify serializer
*
* @param serializer
* @return
*/
public Builder<I,O> binarySerializer(BinarySerializer<O> serializer) {
this.binarySerializer = serializer;
return this;
}
/**
* This method allows to specify deserializer
*
* @param deserializer
* @return
*/
public Builder<I,O> binaryDeserializer(BinaryDeserializer<I> deserializer) {
this.binaryDeserializer = deserializer;
return this;
}
/** /**
* This method allows to specify port * This method allows to specify port
* *
@ -194,8 +277,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
} }
public DL4jServlet<I,O> build() { public DL4jServlet<I,O> build() {
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) : return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer) :
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer); new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer);
} }
} }
} }

View File

@ -34,6 +34,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.remote.SameDiffJsonModelServer; import org.nd4j.remote.SameDiffJsonModelServer;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.JsonSerializer;
@ -70,28 +72,40 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
protected boolean enabledParallel = true; protected boolean enabledParallel = true;
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) { protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter,
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
super(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
} }
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter,
super(inferenceAdapter, serializer, deserializer, port); JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
this.cgModel = cgModel; this.cgModel = cgModel;
this.inferenceMode = inferenceMode; this.inferenceMode = inferenceMode;
this.numWorkers = numWorkers; this.numWorkers = numWorkers;
} }
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter,
super(inferenceAdapter, serializer, deserializer, port); JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
this.mlnModel = mlnModel; this.mlnModel = mlnModel;
this.inferenceMode = inferenceMode; this.inferenceMode = inferenceMode;
this.numWorkers = numWorkers; this.numWorkers = numWorkers;
} }
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) { protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter,
super(inferenceAdapter, serializer, deserializer, port); JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port) {
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
this.parallelInference = pi; this.parallelInference = pi;
} }
@ -139,6 +153,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
.parallelEnabled(true) .parallelEnabled(true)
.serializer(serializer) .serializer(serializer)
.deserializer(deserializer) .deserializer(deserializer)
.binarySerializer(binarySerializer)
.binaryDeserializer(binaryDeserializer)
.inferenceAdapter(inferenceAdapter) .inferenceAdapter(inferenceAdapter)
.build(); .build();
} }
@ -147,6 +163,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
.parallelEnabled(false) .parallelEnabled(false)
.serializer(serializer) .serializer(serializer)
.deserializer(deserializer) .deserializer(deserializer)
.binarySerializer(binarySerializer)
.binaryDeserializer(binaryDeserializer)
.inferenceAdapter(inferenceAdapter) .inferenceAdapter(inferenceAdapter)
.build(); .build();
} }
@ -175,6 +193,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
private InferenceAdapter<I, O> inferenceAdapter; private InferenceAdapter<I, O> inferenceAdapter;
private JsonSerializer<O> serializer; private JsonSerializer<O> serializer;
private JsonDeserializer<I> deserializer; private JsonDeserializer<I> deserializer;
private BinarySerializer<O> binarySerializer;
private BinaryDeserializer<I> binaryDeserializer;
private InputAdapter<I> inputAdapter; private InputAdapter<I> inputAdapter;
private OutputAdapter<O> outputAdapter; private OutputAdapter<O> outputAdapter;
@ -238,7 +258,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
} }
/** /**
* This method allows you to specify serializer * This method allows you to specify JSON serializer.
* Incompatible with {@link #outputBinarySerializer(BinarySerializer)}
* Only one serializer - deserializer pair can be used by client and server.
* *
* @param serializer * @param serializer
* @return * @return
@ -249,7 +271,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
} }
/** /**
* This method allows you to specify deserializer * This method allows you to specify JSON deserializer.
* Incompatible with {@link #inputBinaryDeserializer(BinaryDeserializer)}
* Only one serializer - deserializer pair can be used by client and server.
* *
* @param deserializer * @param deserializer
* @return * @return
@ -259,6 +283,32 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
return this; return this;
} }
/**
* This method allows you to specify binary serializer.
* Incompatible with {@link #outputSerializer(JsonSerializer)}
* Only one serializer - deserializer pair can be used by client and server.
*
* @param serializer
* @return
*/
public Builder<I,O> outputBinarySerializer(@NonNull BinarySerializer<O> serializer) {
this.binarySerializer = serializer;
return this;
}
/**
* This method allows you to specify binary deserializer
* Incompatible with {@link #inputDeserializer(JsonDeserializer)}
* Only one serializer - deserializer pair can be used by client and server.
*
* @param deserializer
* @return
*/
public Builder<I,O> inputBinaryDeserializer(@NonNull BinaryDeserializer<I> deserializer) {
this.binaryDeserializer = deserializer;
return this;
}
/** /**
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details * This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
* *
@ -375,17 +425,24 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured"); throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
} }
JsonModelServer server = null;
if (sdModel != null) { if (sdModel != null) {
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined"); server = new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); }
} else if (cgModel != null) else if (cgModel != null) {
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); server = new JsonModelServer<I, O>(cgModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
else if (mlnModel != null) }
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); else if (mlnModel != null) {
else if (pi != null) server = new JsonModelServer<I, O>(mlnModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port); }
else if (pi != null) {
server = new JsonModelServer<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
}
else else
throw new IllegalStateException("No models were defined for JsonModelServer"); throw new IllegalStateException("No models were defined for JsonModelServer");
server.enabledParallel = parallelMode;
return server;
} }
} }

View File

@ -0,0 +1,276 @@
package org.deeplearning4j.remote;
import lombok.val;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.remote.helpers.ImageConversionUtils;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.After;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.remote.clients.JsonRemoteInference;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.Buffer;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
import static org.junit.Assert.*;
public class BinaryModelServerTest {
private final int PORT = 18080;
@After
public void pause() throws Exception {
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
TimeUnit.SECONDS.sleep(2);
}
// Internal test for locally defined serializers
@Test
public void testBufferedImageSerde() {
BinarySerializer<BufferedImage> serde = new BinaryModelServerTest.BufferedImageSerde();
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
byte[] serialized = serde.serialize(image);
BufferedImage deserialized = ((BufferedImageSerde) serde).deserialize(serialized);
int originalSize = image.getData().getDataBuffer().getSize();
assertEquals(originalSize, deserialized.getData().getDataBuffer().getSize());
for (int i = 0; i < originalSize; ++i) {
assertEquals(deserialized.getData().getDataBuffer().getElem(i),
image.getData().getDataBuffer().getElem(i));
}
}
@Test
public void testImageToINDArray() {
INDArray data = ImageConversionUtils.makeRandomImageAsINDArray(28,28,1);
assertNotNull(data);
}
@Test
public void testMlnMnist_ImageInput() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
.outputSerializer(new IntegerSerde())
.inputBinaryDeserializer(new BufferedImageSerde())
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
@Override
public MultiDataSet apply(BufferedImage input) {
INDArray data = null;
try {
data = new Java2DNativeImageLoader().asMatrix(input);
data = data.reshape(1, 784);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return new MultiDataSet(data, null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.port(PORT)
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.parallelMode(false)
.build();
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.inputBinarySerializer(new BufferedImageSerde())
.outputDeserializer(new IntegerSerde())
.build();
try {
server.start();
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
Integer result = client.predict(image);
assertNotNull(result);
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
image = ImageIO.read(new FileInputStream(file));
result = client.predict(image);
assertEquals(new Integer(0), result);
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
image = ImageIO.read(new FileInputStream(file));
result = client.predict(image);
assertEquals(new Integer(1), result);
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
@Test
public void testMlnMnist_ImageInput_Async() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
.outputSerializer(new IntegerSerde())
.inputBinaryDeserializer(new BufferedImageSerde())
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
@Override
public MultiDataSet apply(BufferedImage input) {
INDArray data = null;
try {
data = new Java2DNativeImageLoader().asMatrix(input);
data = data.reshape(1, 784);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return new MultiDataSet(data, null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.port(PORT)
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.parallelMode(false)
.build();
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.inputBinarySerializer(new BufferedImageSerde())
.outputDeserializer(new IntegerSerde())
.build();
try {
server.start();
BufferedImage[] images = new BufferedImage[3];
images[0] = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
images[1] = ImageIO.read(new FileInputStream(file));
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
images[2] = ImageIO.read(new FileInputStream(file));
Future<Integer>[] results = new Future[3];
for (int i = 0; i < images.length; ++i) {
results[i] = client.predictAsync(images[i]);
assertNotNull(results[i]);
}
assertNotNull(results[0].get());
assertEquals(new Integer(0), results[1].get());
assertEquals(new Integer(1), results[2].get());
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
@Test
public void testBinaryIn_BinaryOut() throws Exception {
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
.outputBinarySerializer(new BufferedImageSerde())
.inputBinaryDeserializer(new BufferedImageSerde())
.inferenceAdapter(new InferenceAdapter<BufferedImage, BufferedImage>() {
@Override
public MultiDataSet apply(BufferedImage input) {
INDArray data = null;
try {
data = new Java2DNativeImageLoader().asMatrix(input);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return new MultiDataSet(data, null);
}
@Override
public BufferedImage apply(INDArray... nnOutput) {
return ImageConversionUtils.makeRandomBufferedImage(28,28,3);
}
})
.port(PORT)
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.parallelMode(false)
.build();
val client = JsonRemoteInference.<BufferedImage, BufferedImage>builder()
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.inputBinarySerializer(new BufferedImageSerde())
.outputBinaryDeserializer(new BufferedImageSerde())
.build();
try {
server.start();
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
BufferedImage result = client.predict(image);
assertNotNull(result);
assertEquals(28, result.getHeight());
assertEquals(28, result.getWidth());
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
private static class BufferedImageSerde implements BinarySerializer<BufferedImage>, BinaryDeserializer<BufferedImage> {
@Override
public BufferedImage deserialize(byte[] buffer) {
try {
BufferedImage img = ImageIO.read(new ByteArrayInputStream(buffer));
return img;
} catch (IOException e){
throw new RuntimeException(e);
}
}
@Override
public byte[] serialize(BufferedImage image) {
try{
val baos = new ByteArrayOutputStream();
ImageIO.write(image, "bmp", baos);
byte[] bytes = baos.toByteArray();
return bytes;
} catch (IOException e){
throw new RuntimeException(e);
}
}
}
}

View File

@ -0,0 +1,82 @@
package org.deeplearning4j.remote.helpers;
import lombok.val;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Mat;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Random;
import static org.bytedeco.opencv.global.opencv_core.CV_8UC;
public class ImageConversionUtils {
public static Mat makeRandomImage(int height, int width, int channels) {
if (height <= 0) {
height = new Random().nextInt() % 100 + 100;
}
if (width <= 0) {
width = new Random().nextInt() % 100 + 100;
}
Mat img = new Mat(height, width, CV_8UC(channels));
UByteIndexer idx = img.createIndexer();
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
for (int k = 0; k < channels; k++) {
idx.put(i, j, k, new Random().nextInt());
}
}
}
return img;
}
public static BufferedImage makeRandomBufferedImage(int height, int width, int channels) {
Mat img = makeRandomImage(height, width, channels);
OpenCVFrameConverter.ToMat c = new OpenCVFrameConverter.ToMat();
Java2DFrameConverter c2 = new Java2DFrameConverter();
return c2.convert(c.convert(img));
}
public static INDArray convert(BufferedImage image) {
INDArray retVal = null;
try {
retVal = new Java2DNativeImageLoader(image.getHeight(), image.getWidth(), image.getRaster().getNumBands()).
asRowVector(image);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return retVal;
}
public static INDArray convert(Mat image) {
INDArray retVal = null;
try {
new NativeImageLoader().asRowVector(image);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return retVal;
}
public static BufferedImage convert(INDArray input) {
return new Java2DNativeImageLoader(input.rows(),input.columns()).asBufferedImage(input);
}
public static INDArray makeRandomImageAsINDArray(int height, int width, int channels) {
val image = makeRandomBufferedImage(height, width, channels);
INDArray retVal = convert(image);
return retVal;
}
}

View File

@ -69,7 +69,7 @@
</developer> </developer>
<developer> <developer>
<id>raver119</id> <id>raver119</id>
<name>raver119</name> <name>Vyacheslav Kokorin</name>
</developer> </developer>
<developer> <developer>
<id>saudet</id> <id>saudet</id>

View File

@ -24,10 +24,12 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.json.JSONObject; import org.json.JSONObject;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.*;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -49,21 +51,57 @@ import java.util.concurrent.TimeoutException;
@Slf4j @Slf4j
public class JsonRemoteInference<I, O> { public class JsonRemoteInference<I, O> {
private String endpointAddress; private String endpointAddress;
// JSON serializer/deserializer and binary serializer/deserializer are mutually exclusive.
private JsonSerializer<I> serializer; private JsonSerializer<I> serializer;
private JsonDeserializer<O> deserializer; private JsonDeserializer<O> deserializer;
private BinarySerializer<I> binarySerializer;
private BinaryDeserializer<O> binaryDeserializer;
private final static String APPLICATION_JSON = "application/json";
private final static String APPLICATION_OCTET_STREAM = "application/octet-stream";
@Builder @Builder
public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer<I> inputSerializer, @NonNull JsonDeserializer<O> outputDeserializer) { public JsonRemoteInference(@NonNull String endpointAddress,
JsonSerializer<I> inputSerializer, JsonDeserializer<O> outputDeserializer,
BinarySerializer<I> inputBinarySerializer, BinaryDeserializer<O> outputBinaryDeserializer) {
this.endpointAddress = endpointAddress; this.endpointAddress = endpointAddress;
this.serializer = inputSerializer; this.serializer = inputSerializer;
this.deserializer = outputDeserializer; this.deserializer = outputDeserializer;
this.binarySerializer = inputBinarySerializer;
this.binaryDeserializer = outputBinaryDeserializer;
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
} }
private O processResponse(HttpResponse<String> response) throws IOException { private O processResponse(HttpResponse<String> response) throws IOException {
if (response.getStatus() != 200) if (response.getStatus() != 200)
throw new IOException("Inference request returned bad error code: " + response.getStatus()); throw new IOException("Inference request returned bad error code: " + response.getStatus());
O result = deserializer.deserialize(response.getBody()); O result = deserializer.deserialize(response.getBody());
if (result == null) {
throw new IOException("Deserialization failed!");
}
return result;
}
private O processResponseBinary(HttpResponse<InputStream> response) throws IOException {
if (response.getStatus() != 200)
throw new IOException("Inference request returned bad error code: " + response.getStatus());
List<String> values = response.getHeaders().get("Content-Length");
if (values == null || values.size() < 1) {
throw new IOException("Content-Length is required for binary data");
}
String strLength = values.get(0);
byte[] bytes = new byte[Integer.parseInt(strLength)];
response.getBody().read(bytes);
O result = binaryDeserializer.deserialize(bytes);
if (result == null) { if (result == null) {
throw new IOException("Deserialization failed!"); throw new IOException("Deserialization failed!");
} }
@ -79,12 +117,30 @@ public class JsonRemoteInference<I, O> {
*/ */
public O predict(I input) throws IOException { public O predict(I input) throws IOException {
try { try {
val stringResult = Unirest.post(endpointAddress) if (binarySerializer != null && binaryDeserializer != null) {
.header("Content-Type", "application/json") HttpResponse<InputStream> response =
.header("Accept", "application/json") Unirest.post(endpointAddress)
.header("Content-Type", APPLICATION_OCTET_STREAM)
.header("Accept", APPLICATION_OCTET_STREAM)
.body(binarySerializer.serialize(input)).asBinary();
return processResponseBinary(response);
}
else if (binarySerializer != null && binaryDeserializer == null) {
HttpResponse<String> response =
Unirest.post(endpointAddress)
.header("Content-Type", APPLICATION_OCTET_STREAM)
.header("Accept", APPLICATION_OCTET_STREAM)
.body(binarySerializer.serialize(input)).asString();
return processResponse(response);
}
else {
HttpResponse<String> response = Unirest.post(endpointAddress)
.header("Content-Type", APPLICATION_JSON)
.header("Accept", APPLICATION_JSON)
.body(new JSONObject(serializer.serialize(input))).asString(); .body(new JSONObject(serializer.serialize(input))).asString();
return processResponse(response);
}
return processResponse(stringResult);
} catch (UnirestException e) { } catch (UnirestException e) {
throw new IOException(e); throw new IOException(e);
} }
@ -96,11 +152,19 @@ public class JsonRemoteInference<I, O> {
* @return * @return
*/ */
public Future<O> predictAsync(I input) { public Future<O> predictAsync(I input) {
val stringResult = Unirest.post(endpointAddress)
Future<HttpResponse<String>> response = binarySerializer != null ?
Unirest.post(endpointAddress)
.header("Content-Type", "application/octet-stream")
.header("Accept", "application/octet-stream")
.body(binarySerializer.serialize(input)).asStringAsync() :
Unirest.post(endpointAddress)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Accept", "application/json") .header("Accept", "application/json")
.body(new JSONObject(serializer.serialize(input))).asStringAsync(); .body(new JSONObject(serializer.serialize(input))).asStringAsync();
return new InferenceFuture(stringResult);
return new InferenceFuture(response);
} }
/** /**
@ -151,3 +215,4 @@ public class JsonRemoteInference<I, O> {
} }
} }
} }

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde;
/**
* This interface describes basic binary deserializer interface used for remote inference
* @param <T> type of the deserializable class
*
* @author Alexander Stoyakin
*/
public interface BinaryDeserializer<T> {
/**
* This method deserializes binary data to arbitrary object.
* @param byte buffer
* @return deserialized object
*/
T deserialize(byte[] buffer);
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde;
/**
* This interface describes basic binary serializer interface used for remote inference
* @param <T> type of the serializable class
*
* @author Alexander Stoyakin
*/
public interface BinarySerializer<T> {
/**
* This method serializes given object into byte buffer
*
* @param o object to be serialized
* @return
*/
byte[] serialize(T o);
}

View File

@ -26,6 +26,8 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter; import org.nd4j.adapters.InferenceAdapter;
@ -51,6 +53,8 @@ public class SameDiffJsonModelServer<I, O> {
protected SameDiff sdModel; protected SameDiff sdModel;
protected final JsonSerializer<O> serializer; protected final JsonSerializer<O> serializer;
protected final JsonDeserializer<I> deserializer; protected final JsonDeserializer<I> deserializer;
protected final BinarySerializer<O> binarySerializer;
protected final BinaryDeserializer<I> binaryDeserializer;
protected final InferenceAdapter<I, O> inferenceAdapter; protected final InferenceAdapter<I, O> inferenceAdapter;
protected final int port; protected final int port;
@ -64,9 +68,18 @@ public class SameDiffJsonModelServer<I, O> {
protected String[] orderedInputNodes; protected String[] orderedInputNodes;
protected String[] orderedOutputNodes; protected String[] orderedOutputNodes;
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port) { protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter,
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port) {
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535"); Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
Preconditions.checkArgument(serializer == null && binarySerializer == null ||
serializer != null && binarySerializer == null ||
serializer == null && binarySerializer != null,
"JSON and binary serializers/deserializers are mutually exclusive and mandatory.");
this.binarySerializer = binarySerializer;
this.binaryDeserializer = binaryDeserializer;
this.inferenceAdapter = inferenceAdapter; this.inferenceAdapter = inferenceAdapter;
this.serializer = serializer; this.serializer = serializer;
this.deserializer = deserializer; this.deserializer = deserializer;
@ -74,8 +87,11 @@ public class SameDiffJsonModelServer<I, O> {
} }
//@Builder //@Builder
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) { public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter,
this(inferenceAdapter, serializer, deserializer, port); JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
this(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
this.sdModel = sdModel; this.sdModel = sdModel;
this.orderedInputNodes = orderedInputNodes; this.orderedInputNodes = orderedInputNodes;
this.orderedOutputNodes = orderedOutputNodes; this.orderedOutputNodes = orderedOutputNodes;
@ -282,7 +298,7 @@ public class SameDiffJsonModelServer<I, O> {
} else } else
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured"); throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
} }
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, null, null, port, orderedInputNodes, orderedOutputNodes);
} }
} }
} }

View File

@ -21,6 +21,8 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer; import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer; import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter; import org.nd4j.adapters.InferenceAdapter;
@ -35,6 +37,7 @@ import java.io.InputStreamReader;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import static javax.ws.rs.core.MediaType.APPLICATION_JSON; import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
import static javax.ws.rs.core.MediaType.APPLICATION_OCTET_STREAM;
/** /**
* This servlet provides SameDiff model serving capabilities * This servlet provides SameDiff model serving capabilities
@ -50,9 +53,14 @@ import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
@Builder @Builder
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> { public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
protected static final String typeJson = APPLICATION_JSON;
protected static final String typeBinary = APPLICATION_OCTET_STREAM;
protected SameDiff sdModel; protected SameDiff sdModel;
protected JsonSerializer<O> serializer; protected JsonSerializer<O> serializer;
protected JsonDeserializer<I> deserializer; protected JsonDeserializer<I> deserializer;
protected BinarySerializer<O> binarySerializer;
protected BinaryDeserializer<I> binaryDeserializer;
protected InferenceAdapter<I, O> inferenceAdapter; protected InferenceAdapter<I, O> inferenceAdapter;
protected String[] orderedInputNodes; protected String[] orderedInputNodes;
@ -60,14 +68,36 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
protected final static String SERVING_ENDPOINT = "/v1/serving"; protected final static String SERVING_ENDPOINT = "/v1/serving";
protected final static String LISTING_ENDPOINT = "/v1"; protected final static String LISTING_ENDPOINT = "/v1";
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024 * 1024; // TODO: should be customizable
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer){ protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer){
this.serializer = serializer; this.serializer = serializer;
this.deserializer = deserializer; this.deserializer = deserializer;
this.inferenceAdapter = inferenceAdapter; this.inferenceAdapter = inferenceAdapter;
} }
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer){
this.binarySerializer = serializer;
this.binaryDeserializer = deserializer;
this.inferenceAdapter = inferenceAdapter;
}
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer){
this.serializer = jsonSerializer;
this.deserializer = jsonDeserializer;
this.binarySerializer = binarySerializer;
this.binaryDeserializer = binaryDeserializer;
this.inferenceAdapter = inferenceAdapter;
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
}
@Override @Override
public void init(ServletConfig servletConfig) throws ServletException { public void init(ServletConfig servletConfig) throws ServletException {
// //
@ -108,7 +138,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response) protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
throws IOException{ throws IOException{
val contentType = request.getContentType(); val contentType = request.getContentType();
if (!StringUtils.equals(contentType, APPLICATION_JSON)) { if (!StringUtils.equals(contentType, typeJson)) {
sendBadContentType(contentType, response); sendBadContentType(contentType, response);
int contentLength = request.getContentLength(); int contentLength = request.getContentLength();
if (contentLength > PAYLOAD_SIZE_LIMIT) { if (contentLength > PAYLOAD_SIZE_LIMIT) {
@ -125,7 +155,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
String path = request.getPathInfo(); String path = request.getPathInfo();
if (path.equals(LISTING_ENDPOINT)) { if (path.equals(LISTING_ENDPOINT)) {
val contentType = request.getContentType(); val contentType = request.getContentType();
if (!StringUtils.equals(contentType, APPLICATION_JSON)) { if (!StringUtils.equals(contentType, typeJson)) {
sendBadContentType(contentType, response); sendBadContentType(contentType, response);
} }
processorReturned = processor.listEndpoints(); processorReturned = processor.listEndpoints();
@ -147,7 +177,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
String path = request.getPathInfo(); String path = request.getPathInfo();
if (path.equals(SERVING_ENDPOINT)) { if (path.equals(SERVING_ENDPOINT)) {
val contentType = request.getContentType(); val contentType = request.getContentType();
/*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON), /*Preconditions.checkArgument(StringUtils.equals(contentType, typeJson),
"Content type is " + contentType);*/ "Content type is " + contentType);*/
if (validateRequest(request,response)) { if (validateRequest(request,response)) {
val stream = request.getInputStream(); val stream = request.getInputStream();