[WIP] Handling binary data in DL4J servlet (#135)
* Binary deser Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary mode for servlet Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * -sRandom image generation copied from datavec * -sRandom image generation copied from datavec * Remove serialization constraints * Fix: Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Removed unused code Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Resources usage Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Async inference Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Cleanup Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * -sTest corrected * Cleanup Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Mutually eclusive serializers/deserializers Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary output supported Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Binary out test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * - types hardcoded - increased payload size limit Signed-off-by: raver119 <raver119@gmail.com> * change types constant Signed-off-by: raver119 <raver119@gmail.com>master
parent
8e3d569f18
commit
2e99bc2dee
|
@ -18,18 +18,17 @@ package org.deeplearning4j.remote;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
|
||||||
import org.deeplearning4j.nn.api.Model;
|
import org.deeplearning4j.nn.api.Model;
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.parallelism.ParallelInference;
|
import org.deeplearning4j.parallelism.ParallelInference;
|
||||||
import org.nd4j.adapters.InferenceAdapter;
|
import org.nd4j.adapters.InferenceAdapter;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
import org.nd4j.adapters.InferenceAdapter;
|
|
||||||
import org.nd4j.remote.serving.SameDiffServlet;
|
import org.nd4j.remote.serving.SameDiffServlet;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
@ -38,6 +37,8 @@ import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStreamReader;
|
import java.io.InputStreamReader;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @author astoyakin
|
* @author astoyakin
|
||||||
|
@ -51,7 +52,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
protected boolean parallelEnabled = true;
|
protected boolean parallelEnabled = true;
|
||||||
|
|
||||||
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
|
||||||
super(inferenceAdapter, serializer, deserializer);
|
super(inferenceAdapter, serializer, deserializer);
|
||||||
this.parallelInference = parallelInference;
|
this.parallelInference = parallelInference;
|
||||||
this.model = null;
|
this.model = null;
|
||||||
|
@ -59,61 +60,119 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer) {
|
||||||
super(inferenceAdapter, serializer, deserializer);
|
super(inferenceAdapter, serializer, deserializer);
|
||||||
this.model = model;
|
this.model = model;
|
||||||
this.parallelInference = null;
|
this.parallelInference = null;
|
||||||
this.parallelEnabled = false;
|
this.parallelEnabled = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer) {
|
||||||
|
super(inferenceAdapter, serializer, deserializer);
|
||||||
|
this.parallelInference = parallelInference;
|
||||||
|
this.model = null;
|
||||||
|
this.parallelEnabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
|
||||||
|
super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
|
||||||
|
this.model = model;
|
||||||
|
this.parallelInference = null;
|
||||||
|
this.parallelEnabled = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
|
||||||
|
super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
|
||||||
|
this.parallelInference = parallelInference;
|
||||||
|
this.model = null;
|
||||||
|
this.parallelEnabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private O process(MultiDataSet mds) {
|
||||||
|
O result = null;
|
||||||
|
if (parallelEnabled) {
|
||||||
|
// process result
|
||||||
|
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||||
|
} else {
|
||||||
|
synchronized (this) {
|
||||||
|
if (model instanceof ComputationGraph)
|
||||||
|
result = inferenceAdapter.apply(((ComputationGraph) model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||||
|
else if (model instanceof MultiLayerNetwork) {
|
||||||
|
Preconditions.checkArgument(mds.getFeatures().length > 0 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 0),
|
||||||
|
"Input data for MultilayerNetwork is invalid!");
|
||||||
|
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
||||||
|
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||||
String processorReturned = "";
|
String processorReturned = "";
|
||||||
|
MultiDataSet mds = null;
|
||||||
String path = request.getPathInfo();
|
String path = request.getPathInfo();
|
||||||
if (path.equals(SERVING_ENDPOINT)) {
|
if (path.equals(SERVING_ENDPOINT)) {
|
||||||
val contentType = request.getContentType();
|
val contentType = request.getContentType();
|
||||||
if (validateRequest(request,response)) {
|
if (contentType.equals(typeJson)) {
|
||||||
val stream = request.getInputStream();
|
if (validateRequest(request, response)) {
|
||||||
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
val stream = request.getInputStream();
|
||||||
char[] charBuffer = new char[128];
|
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
||||||
int bytesRead = -1;
|
char[] charBuffer = new char[128];
|
||||||
val buffer = new StringBuilder();
|
int bytesRead = -1;
|
||||||
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
|
val buffer = new StringBuilder();
|
||||||
buffer.append(charBuffer, 0, bytesRead);
|
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
|
||||||
|
buffer.append(charBuffer, 0, bytesRead);
|
||||||
|
}
|
||||||
|
val requestString = buffer.toString();
|
||||||
|
|
||||||
|
mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||||
}
|
}
|
||||||
val requestString = buffer.toString();
|
}
|
||||||
|
else if (contentType.equals(typeBinary)) {
|
||||||
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
val stream = request.getInputStream();
|
||||||
|
int available = request.getContentLength();
|
||||||
O result = null;
|
if (available <= 0) {
|
||||||
if (parallelEnabled) {
|
response.sendError(411, "Content length is unavailable");
|
||||||
// process result
|
|
||||||
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
synchronized(this) {
|
byte[] data = new byte[available];
|
||||||
if (model instanceof ComputationGraph)
|
stream.read(data, 0, available);
|
||||||
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
|
||||||
else if (model instanceof MultiLayerNetwork) {
|
mds = inferenceAdapter.apply(binaryDeserializer.deserialize(data));
|
||||||
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
|
}
|
||||||
"Input data for MultilayerNetwork is invalid!");
|
}
|
||||||
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
if (mds == null)
|
||||||
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
log.error("InferenceAdapter failed");
|
||||||
}
|
else {
|
||||||
|
val result = process(mds);
|
||||||
|
if (binarySerializer != null) {
|
||||||
|
byte[] serialized = binarySerializer.serialize(result);
|
||||||
|
response.setContentType(typeBinary);
|
||||||
|
response.setContentLength(serialized.length);
|
||||||
|
val out = response.getOutputStream();
|
||||||
|
out.write(serialized);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
processorReturned = serializer.serialize(result);
|
||||||
|
try {
|
||||||
|
val out = response.getWriter();
|
||||||
|
out.write(processorReturned);
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error(e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
processorReturned = serializer.serialize(result);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// we return error otherwise
|
// we return error otherwise
|
||||||
sendError(request.getRequestURI(), response);
|
sendError(request.getRequestURI(), response);
|
||||||
}
|
}
|
||||||
try {
|
|
||||||
val out = response.getWriter();
|
|
||||||
out.write(processorReturned);
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error(e.getMessage());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -133,6 +192,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
private InferenceAdapter<I, O> inferenceAdapter;
|
private InferenceAdapter<I, O> inferenceAdapter;
|
||||||
private JsonSerializer<O> serializer;
|
private JsonSerializer<O> serializer;
|
||||||
private JsonDeserializer<I> deserializer;
|
private JsonDeserializer<I> deserializer;
|
||||||
|
private BinarySerializer<O> binarySerializer;
|
||||||
|
private BinaryDeserializer<I> binaryDeserializer;
|
||||||
private int port;
|
private int port;
|
||||||
private boolean parallelEnabled = true;
|
private boolean parallelEnabled = true;
|
||||||
|
|
||||||
|
@ -155,7 +216,7 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
* @param serializer
|
* @param serializer
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) {
|
public Builder<I,O> serializer(JsonSerializer<O> serializer) {
|
||||||
this.serializer = serializer;
|
this.serializer = serializer;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -166,11 +227,33 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
* @param deserializer
|
* @param deserializer
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
|
public Builder<I,O> deserializer(JsonDeserializer<I> deserializer) {
|
||||||
this.deserializer = deserializer;
|
this.deserializer = deserializer;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method is required to specify serializer
|
||||||
|
*
|
||||||
|
* @param serializer
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Builder<I,O> binarySerializer(BinarySerializer<O> serializer) {
|
||||||
|
this.binarySerializer = serializer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows to specify deserializer
|
||||||
|
*
|
||||||
|
* @param deserializer
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Builder<I,O> binaryDeserializer(BinaryDeserializer<I> deserializer) {
|
||||||
|
this.binaryDeserializer = deserializer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method allows to specify port
|
* This method allows to specify port
|
||||||
*
|
*
|
||||||
|
@ -194,8 +277,8 @@ public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public DL4jServlet<I,O> build() {
|
public DL4jServlet<I,O> build() {
|
||||||
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
|
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer) :
|
||||||
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
|
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.remote.SameDiffJsonModelServer;
|
import org.nd4j.remote.SameDiffJsonModelServer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
|
|
||||||
|
@ -70,28 +72,40 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
|
|
||||||
protected boolean enabledParallel = true;
|
protected boolean enabledParallel = true;
|
||||||
|
|
||||||
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||||
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
||||||
|
super(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||||
super(inferenceAdapter, serializer, deserializer, port);
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||||
|
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||||
|
|
||||||
this.cgModel = cgModel;
|
this.cgModel = cgModel;
|
||||||
this.inferenceMode = inferenceMode;
|
this.inferenceMode = inferenceMode;
|
||||||
this.numWorkers = numWorkers;
|
this.numWorkers = numWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter,
|
||||||
super(inferenceAdapter, serializer, deserializer, port);
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||||
|
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||||
|
|
||||||
this.mlnModel = mlnModel;
|
this.mlnModel = mlnModel;
|
||||||
this.inferenceMode = inferenceMode;
|
this.inferenceMode = inferenceMode;
|
||||||
this.numWorkers = numWorkers;
|
this.numWorkers = numWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
|
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter,
|
||||||
super(inferenceAdapter, serializer, deserializer, port);
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port) {
|
||||||
|
super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||||
|
|
||||||
this.parallelInference = pi;
|
this.parallelInference = pi;
|
||||||
}
|
}
|
||||||
|
@ -136,19 +150,23 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
|
servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
|
||||||
.parallelEnabled(true)
|
.parallelEnabled(true)
|
||||||
.serializer(serializer)
|
.serializer(serializer)
|
||||||
.deserializer(deserializer)
|
.deserializer(deserializer)
|
||||||
.inferenceAdapter(inferenceAdapter)
|
.binarySerializer(binarySerializer)
|
||||||
.build();
|
.binaryDeserializer(binaryDeserializer)
|
||||||
|
.inferenceAdapter(inferenceAdapter)
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
servingServlet = new DL4jServlet.Builder<I, O>(model)
|
servingServlet = new DL4jServlet.Builder<I, O>(model)
|
||||||
.parallelEnabled(false)
|
.parallelEnabled(false)
|
||||||
.serializer(serializer)
|
.serializer(serializer)
|
||||||
.deserializer(deserializer)
|
.deserializer(deserializer)
|
||||||
.inferenceAdapter(inferenceAdapter)
|
.binarySerializer(binarySerializer)
|
||||||
.build();
|
.binaryDeserializer(binaryDeserializer)
|
||||||
|
.inferenceAdapter(inferenceAdapter)
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
start(port, servingServlet);
|
start(port, servingServlet);
|
||||||
}
|
}
|
||||||
|
@ -175,6 +193,8 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
private InferenceAdapter<I, O> inferenceAdapter;
|
private InferenceAdapter<I, O> inferenceAdapter;
|
||||||
private JsonSerializer<O> serializer;
|
private JsonSerializer<O> serializer;
|
||||||
private JsonDeserializer<I> deserializer;
|
private JsonDeserializer<I> deserializer;
|
||||||
|
private BinarySerializer<O> binarySerializer;
|
||||||
|
private BinaryDeserializer<I> binaryDeserializer;
|
||||||
|
|
||||||
private InputAdapter<I> inputAdapter;
|
private InputAdapter<I> inputAdapter;
|
||||||
private OutputAdapter<O> outputAdapter;
|
private OutputAdapter<O> outputAdapter;
|
||||||
|
@ -238,7 +258,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method allows you to specify serializer
|
* This method allows you to specify JSON serializer.
|
||||||
|
* Incompatible with {@link #outputBinarySerializer(BinarySerializer)}
|
||||||
|
* Only one serializer - deserializer pair can be used by client and server.
|
||||||
*
|
*
|
||||||
* @param serializer
|
* @param serializer
|
||||||
* @return
|
* @return
|
||||||
|
@ -249,7 +271,9 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method allows you to specify deserializer
|
* This method allows you to specify JSON deserializer.
|
||||||
|
* Incompatible with {@link #inputBinaryDeserializer(BinaryDeserializer)}
|
||||||
|
* Only one serializer - deserializer pair can be used by client and server.
|
||||||
*
|
*
|
||||||
* @param deserializer
|
* @param deserializer
|
||||||
* @return
|
* @return
|
||||||
|
@ -259,6 +283,32 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows you to specify binary serializer.
|
||||||
|
* Incompatible with {@link #outputSerializer(JsonSerializer)}
|
||||||
|
* Only one serializer - deserializer pair can be used by client and server.
|
||||||
|
*
|
||||||
|
* @param serializer
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Builder<I,O> outputBinarySerializer(@NonNull BinarySerializer<O> serializer) {
|
||||||
|
this.binarySerializer = serializer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows you to specify binary deserializer
|
||||||
|
* Incompatible with {@link #inputDeserializer(JsonDeserializer)}
|
||||||
|
* Only one serializer - deserializer pair can be used by client and server.
|
||||||
|
*
|
||||||
|
* @param deserializer
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Builder<I,O> inputBinaryDeserializer(@NonNull BinaryDeserializer<I> deserializer) {
|
||||||
|
this.binaryDeserializer = deserializer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
|
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
|
||||||
*
|
*
|
||||||
|
@ -375,17 +425,24 @@ public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JsonModelServer server = null;
|
||||||
if (sdModel != null) {
|
if (sdModel != null) {
|
||||||
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
|
server = new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||||
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
}
|
||||||
} else if (cgModel != null)
|
else if (cgModel != null) {
|
||||||
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
server = new JsonModelServer<I, O>(cgModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
|
||||||
else if (mlnModel != null)
|
}
|
||||||
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
else if (mlnModel != null) {
|
||||||
else if (pi != null)
|
server = new JsonModelServer<I, O>(mlnModel, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
|
||||||
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port);
|
}
|
||||||
else
|
else if (pi != null) {
|
||||||
throw new IllegalStateException("No models were defined for JsonModelServer");
|
server = new JsonModelServer<I, O>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
throw new IllegalStateException("No models were defined for JsonModelServer");
|
||||||
|
|
||||||
|
server.enabledParallel = parallelMode;
|
||||||
|
return server;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,276 @@
|
||||||
|
package org.deeplearning4j.remote;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
|
import org.datavec.image.loader.Java2DNativeImageLoader;
|
||||||
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.remote.helpers.ImageConversionUtils;
|
||||||
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.adapters.InferenceAdapter;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
import org.nd4j.remote.clients.JsonRemoteInference;
|
||||||
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
|
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
||||||
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import javax.imageio.ImageIO;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.*;
|
||||||
|
import java.nio.Buffer;
|
||||||
|
import java.util.concurrent.Future;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class BinaryModelServerTest {
|
||||||
|
private final int PORT = 18080;
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void pause() throws Exception {
|
||||||
|
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
||||||
|
TimeUnit.SECONDS.sleep(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal test for locally defined serializers
|
||||||
|
@Test
|
||||||
|
public void testBufferedImageSerde() {
|
||||||
|
BinarySerializer<BufferedImage> serde = new BinaryModelServerTest.BufferedImageSerde();
|
||||||
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||||
|
byte[] serialized = serde.serialize(image);
|
||||||
|
|
||||||
|
BufferedImage deserialized = ((BufferedImageSerde) serde).deserialize(serialized);
|
||||||
|
int originalSize = image.getData().getDataBuffer().getSize();
|
||||||
|
assertEquals(originalSize, deserialized.getData().getDataBuffer().getSize());
|
||||||
|
for (int i = 0; i < originalSize; ++i) {
|
||||||
|
assertEquals(deserialized.getData().getDataBuffer().getElem(i),
|
||||||
|
image.getData().getDataBuffer().getElem(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testImageToINDArray() {
|
||||||
|
INDArray data = ImageConversionUtils.makeRandomImageAsINDArray(28,28,1);
|
||||||
|
assertNotNull(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMlnMnist_ImageInput() throws Exception {
|
||||||
|
|
||||||
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||||
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||||
|
.outputSerializer(new IntegerSerde())
|
||||||
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||||
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||||||
|
@Override
|
||||||
|
public MultiDataSet apply(BufferedImage input) {
|
||||||
|
INDArray data = null;
|
||||||
|
try {
|
||||||
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||||
|
data = data.reshape(1, 784);
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
return new MultiDataSet(data, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer apply(INDArray... nnOutput) {
|
||||||
|
return nnOutput[0].argMax().getInt(0);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.port(PORT)
|
||||||
|
.inferenceMode(SEQUENTIAL)
|
||||||
|
.numWorkers(1)
|
||||||
|
.parallelMode(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||||||
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||||
|
.inputBinarySerializer(new BufferedImageSerde())
|
||||||
|
.outputDeserializer(new IntegerSerde())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
server.start();
|
||||||
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||||
|
Integer result = client.predict(image);
|
||||||
|
assertNotNull(result);
|
||||||
|
|
||||||
|
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||||||
|
image = ImageIO.read(new FileInputStream(file));
|
||||||
|
result = client.predict(image);
|
||||||
|
assertEquals(new Integer(0), result);
|
||||||
|
|
||||||
|
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||||||
|
image = ImageIO.read(new FileInputStream(file));
|
||||||
|
result = client.predict(image);
|
||||||
|
assertEquals(new Integer(1), result);
|
||||||
|
|
||||||
|
} catch (Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
throw e;
|
||||||
|
} finally {
|
||||||
|
server.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMlnMnist_ImageInput_Async() throws Exception {
|
||||||
|
|
||||||
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||||
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||||||
|
.outputSerializer(new IntegerSerde())
|
||||||
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||||
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||||||
|
@Override
|
||||||
|
public MultiDataSet apply(BufferedImage input) {
|
||||||
|
INDArray data = null;
|
||||||
|
try {
|
||||||
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||||
|
data = data.reshape(1, 784);
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
return new MultiDataSet(data, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer apply(INDArray... nnOutput) {
|
||||||
|
return nnOutput[0].argMax().getInt(0);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.port(PORT)
|
||||||
|
.inferenceMode(SEQUENTIAL)
|
||||||
|
.numWorkers(1)
|
||||||
|
.parallelMode(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||||||
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||||
|
.inputBinarySerializer(new BufferedImageSerde())
|
||||||
|
.outputDeserializer(new IntegerSerde())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
server.start();
|
||||||
|
BufferedImage[] images = new BufferedImage[3];
|
||||||
|
images[0] = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||||
|
|
||||||
|
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||||||
|
images[1] = ImageIO.read(new FileInputStream(file));
|
||||||
|
|
||||||
|
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||||||
|
images[2] = ImageIO.read(new FileInputStream(file));
|
||||||
|
|
||||||
|
Future<Integer>[] results = new Future[3];
|
||||||
|
for (int i = 0; i < images.length; ++i) {
|
||||||
|
results[i] = client.predictAsync(images[i]);
|
||||||
|
assertNotNull(results[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNotNull(results[0].get());
|
||||||
|
assertEquals(new Integer(0), results[1].get());
|
||||||
|
assertEquals(new Integer(1), results[2].get());
|
||||||
|
|
||||||
|
} catch (Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
throw e;
|
||||||
|
} finally {
|
||||||
|
server.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBinaryIn_BinaryOut() throws Exception {
|
||||||
|
|
||||||
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||||||
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
|
||||||
|
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
||||||
|
.outputBinarySerializer(new BufferedImageSerde())
|
||||||
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||||||
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, BufferedImage>() {
|
||||||
|
@Override
|
||||||
|
public MultiDataSet apply(BufferedImage input) {
|
||||||
|
INDArray data = null;
|
||||||
|
try {
|
||||||
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
return new MultiDataSet(data, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BufferedImage apply(INDArray... nnOutput) {
|
||||||
|
return ImageConversionUtils.makeRandomBufferedImage(28,28,3);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.port(PORT)
|
||||||
|
.inferenceMode(SEQUENTIAL)
|
||||||
|
.numWorkers(1)
|
||||||
|
.parallelMode(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
val client = JsonRemoteInference.<BufferedImage, BufferedImage>builder()
|
||||||
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||||
|
.inputBinarySerializer(new BufferedImageSerde())
|
||||||
|
.outputBinaryDeserializer(new BufferedImageSerde())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
server.start();
|
||||||
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||||||
|
BufferedImage result = client.predict(image);
|
||||||
|
assertNotNull(result);
|
||||||
|
assertEquals(28, result.getHeight());
|
||||||
|
assertEquals(28, result.getWidth());
|
||||||
|
|
||||||
|
} catch (Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
throw e;
|
||||||
|
} finally {
|
||||||
|
server.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class BufferedImageSerde implements BinarySerializer<BufferedImage>, BinaryDeserializer<BufferedImage> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BufferedImage deserialize(byte[] buffer) {
|
||||||
|
try {
|
||||||
|
BufferedImage img = ImageIO.read(new ByteArrayInputStream(buffer));
|
||||||
|
return img;
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public byte[] serialize(BufferedImage image) {
|
||||||
|
try{
|
||||||
|
val baos = new ByteArrayOutputStream();
|
||||||
|
ImageIO.write(image, "bmp", baos);
|
||||||
|
byte[] bytes = baos.toByteArray();
|
||||||
|
return bytes;
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
package org.deeplearning4j.remote.helpers;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.indexer.UByteIndexer;
|
||||||
|
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Mat;
|
||||||
|
import org.datavec.image.loader.Java2DNativeImageLoader;
|
||||||
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.bytedeco.opencv.global.opencv_core.CV_8UC;
|
||||||
|
|
||||||
|
public class ImageConversionUtils {
|
||||||
|
|
||||||
|
public static Mat makeRandomImage(int height, int width, int channels) {
|
||||||
|
if (height <= 0) {
|
||||||
|
|
||||||
|
height = new Random().nextInt() % 100 + 100;
|
||||||
|
}
|
||||||
|
if (width <= 0) {
|
||||||
|
width = new Random().nextInt() % 100 + 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat img = new Mat(height, width, CV_8UC(channels));
|
||||||
|
UByteIndexer idx = img.createIndexer();
|
||||||
|
for (int i = 0; i < height; i++) {
|
||||||
|
for (int j = 0; j < width; j++) {
|
||||||
|
for (int k = 0; k < channels; k++) {
|
||||||
|
idx.put(i, j, k, new Random().nextInt());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return img;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage makeRandomBufferedImage(int height, int width, int channels) {
|
||||||
|
Mat img = makeRandomImage(height, width, channels);
|
||||||
|
|
||||||
|
OpenCVFrameConverter.ToMat c = new OpenCVFrameConverter.ToMat();
|
||||||
|
Java2DFrameConverter c2 = new Java2DFrameConverter();
|
||||||
|
|
||||||
|
return c2.convert(c.convert(img));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray convert(BufferedImage image) {
|
||||||
|
INDArray retVal = null;
|
||||||
|
try {
|
||||||
|
retVal = new Java2DNativeImageLoader(image.getHeight(), image.getWidth(), image.getRaster().getNumBands()).
|
||||||
|
asRowVector(image);
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray convert(Mat image) {
|
||||||
|
INDArray retVal = null;
|
||||||
|
try {
|
||||||
|
new NativeImageLoader().asRowVector(image);
|
||||||
|
}
|
||||||
|
catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage convert(INDArray input) {
|
||||||
|
return new Java2DNativeImageLoader(input.rows(),input.columns()).asBufferedImage(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray makeRandomImageAsINDArray(int height, int width, int channels) {
|
||||||
|
val image = makeRandomBufferedImage(height, width, channels);
|
||||||
|
INDArray retVal = convert(image);
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
|
}
|
|
@ -69,7 +69,7 @@
|
||||||
</developer>
|
</developer>
|
||||||
<developer>
|
<developer>
|
||||||
<id>raver119</id>
|
<id>raver119</id>
|
||||||
<name>raver119</name>
|
<name>Vyacheslav Kokorin</name>
|
||||||
</developer>
|
</developer>
|
||||||
<developer>
|
<developer>
|
||||||
<id>saudet</id>
|
<id>saudet</id>
|
||||||
|
|
|
@ -24,10 +24,12 @@ import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.*;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.Future;
|
import java.util.concurrent.Future;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
@ -49,21 +51,57 @@ import java.util.concurrent.TimeoutException;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JsonRemoteInference<I, O> {
|
public class JsonRemoteInference<I, O> {
|
||||||
private String endpointAddress;
|
private String endpointAddress;
|
||||||
|
// JSON serializer/deserializer and binary serializer/deserializer are mutually exclusive.
|
||||||
private JsonSerializer<I> serializer;
|
private JsonSerializer<I> serializer;
|
||||||
private JsonDeserializer<O> deserializer;
|
private JsonDeserializer<O> deserializer;
|
||||||
|
private BinarySerializer<I> binarySerializer;
|
||||||
|
private BinaryDeserializer<O> binaryDeserializer;
|
||||||
|
|
||||||
|
private final static String APPLICATION_JSON = "application/json";
|
||||||
|
private final static String APPLICATION_OCTET_STREAM = "application/octet-stream";
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer<I> inputSerializer, @NonNull JsonDeserializer<O> outputDeserializer) {
|
public JsonRemoteInference(@NonNull String endpointAddress,
|
||||||
|
JsonSerializer<I> inputSerializer, JsonDeserializer<O> outputDeserializer,
|
||||||
|
BinarySerializer<I> inputBinarySerializer, BinaryDeserializer<O> outputBinaryDeserializer) {
|
||||||
|
|
||||||
this.endpointAddress = endpointAddress;
|
this.endpointAddress = endpointAddress;
|
||||||
this.serializer = inputSerializer;
|
this.serializer = inputSerializer;
|
||||||
this.deserializer = outputDeserializer;
|
this.deserializer = outputDeserializer;
|
||||||
|
this.binarySerializer = inputBinarySerializer;
|
||||||
|
this.binaryDeserializer = outputBinaryDeserializer;
|
||||||
|
|
||||||
|
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
|
||||||
|
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private O processResponse(HttpResponse<String> response) throws IOException {
|
private O processResponse(HttpResponse<String> response) throws IOException {
|
||||||
if (response.getStatus() != 200)
|
if (response.getStatus() != 200)
|
||||||
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
||||||
|
|
||||||
O result = deserializer.deserialize(response.getBody());
|
O result = deserializer.deserialize(response.getBody());
|
||||||
|
|
||||||
|
if (result == null) {
|
||||||
|
throw new IOException("Deserialization failed!");
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private O processResponseBinary(HttpResponse<InputStream> response) throws IOException {
|
||||||
|
if (response.getStatus() != 200)
|
||||||
|
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
||||||
|
|
||||||
|
List<String> values = response.getHeaders().get("Content-Length");
|
||||||
|
if (values == null || values.size() < 1) {
|
||||||
|
throw new IOException("Content-Length is required for binary data");
|
||||||
|
}
|
||||||
|
|
||||||
|
String strLength = values.get(0);
|
||||||
|
byte[] bytes = new byte[Integer.parseInt(strLength)];
|
||||||
|
response.getBody().read(bytes);
|
||||||
|
O result = binaryDeserializer.deserialize(bytes);
|
||||||
|
|
||||||
if (result == null) {
|
if (result == null) {
|
||||||
throw new IOException("Deserialization failed!");
|
throw new IOException("Deserialization failed!");
|
||||||
}
|
}
|
||||||
|
@ -79,12 +117,30 @@ public class JsonRemoteInference<I, O> {
|
||||||
*/
|
*/
|
||||||
public O predict(I input) throws IOException {
|
public O predict(I input) throws IOException {
|
||||||
try {
|
try {
|
||||||
val stringResult = Unirest.post(endpointAddress)
|
if (binarySerializer != null && binaryDeserializer != null) {
|
||||||
.header("Content-Type", "application/json")
|
HttpResponse<InputStream> response =
|
||||||
.header("Accept", "application/json")
|
Unirest.post(endpointAddress)
|
||||||
.body(new JSONObject(serializer.serialize(input))).asString();
|
.header("Content-Type", APPLICATION_OCTET_STREAM)
|
||||||
|
.header("Accept", APPLICATION_OCTET_STREAM)
|
||||||
|
.body(binarySerializer.serialize(input)).asBinary();
|
||||||
|
return processResponseBinary(response);
|
||||||
|
}
|
||||||
|
else if (binarySerializer != null && binaryDeserializer == null) {
|
||||||
|
HttpResponse<String> response =
|
||||||
|
Unirest.post(endpointAddress)
|
||||||
|
.header("Content-Type", APPLICATION_OCTET_STREAM)
|
||||||
|
.header("Accept", APPLICATION_OCTET_STREAM)
|
||||||
|
.body(binarySerializer.serialize(input)).asString();
|
||||||
|
return processResponse(response);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
HttpResponse<String> response = Unirest.post(endpointAddress)
|
||||||
|
.header("Content-Type", APPLICATION_JSON)
|
||||||
|
.header("Accept", APPLICATION_JSON)
|
||||||
|
.body(new JSONObject(serializer.serialize(input))).asString();
|
||||||
|
return processResponse(response);
|
||||||
|
}
|
||||||
|
|
||||||
return processResponse(stringResult);
|
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
throw new IOException(e);
|
throw new IOException(e);
|
||||||
}
|
}
|
||||||
|
@ -96,11 +152,19 @@ public class JsonRemoteInference<I, O> {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Future<O> predictAsync(I input) {
|
public Future<O> predictAsync(I input) {
|
||||||
val stringResult = Unirest.post(endpointAddress)
|
|
||||||
|
Future<HttpResponse<String>> response = binarySerializer != null ?
|
||||||
|
Unirest.post(endpointAddress)
|
||||||
|
.header("Content-Type", "application/octet-stream")
|
||||||
|
.header("Accept", "application/octet-stream")
|
||||||
|
.body(binarySerializer.serialize(input)).asStringAsync() :
|
||||||
|
|
||||||
|
Unirest.post(endpointAddress)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Accept", "application/json")
|
.header("Accept", "application/json")
|
||||||
.body(new JSONObject(serializer.serialize(input))).asStringAsync();
|
.body(new JSONObject(serializer.serialize(input))).asStringAsync();
|
||||||
return new InferenceFuture(stringResult);
|
|
||||||
|
return new InferenceFuture(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -151,3 +215,4 @@ public class JsonRemoteInference<I, O> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.remote.clients.serde;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This interface describes basic binary deserializer interface used for remote inference
|
||||||
|
* @param <T> type of the deserializable class
|
||||||
|
*
|
||||||
|
* @author Alexander Stoyakin
|
||||||
|
*/
|
||||||
|
public interface BinaryDeserializer<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method deserializes binary data to arbitrary object.
|
||||||
|
* @param byte buffer
|
||||||
|
* @return deserialized object
|
||||||
|
*/
|
||||||
|
T deserialize(byte[] buffer);
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.remote.clients.serde;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This interface describes basic binary serializer interface used for remote inference
|
||||||
|
* @param <T> type of the serializable class
|
||||||
|
*
|
||||||
|
* @author Alexander Stoyakin
|
||||||
|
*/
|
||||||
|
public interface BinarySerializer<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method serializes given object into byte buffer
|
||||||
|
*
|
||||||
|
* @param o object to be serialized
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
byte[] serialize(T o);
|
||||||
|
}
|
|
@ -26,6 +26,8 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
import org.nd4j.adapters.InferenceAdapter;
|
import org.nd4j.adapters.InferenceAdapter;
|
||||||
|
@ -51,6 +53,8 @@ public class SameDiffJsonModelServer<I, O> {
|
||||||
protected SameDiff sdModel;
|
protected SameDiff sdModel;
|
||||||
protected final JsonSerializer<O> serializer;
|
protected final JsonSerializer<O> serializer;
|
||||||
protected final JsonDeserializer<I> deserializer;
|
protected final JsonDeserializer<I> deserializer;
|
||||||
|
protected final BinarySerializer<O> binarySerializer;
|
||||||
|
protected final BinaryDeserializer<I> binaryDeserializer;
|
||||||
protected final InferenceAdapter<I, O> inferenceAdapter;
|
protected final InferenceAdapter<I, O> inferenceAdapter;
|
||||||
protected final int port;
|
protected final int port;
|
||||||
|
|
||||||
|
@ -64,9 +68,18 @@ public class SameDiffJsonModelServer<I, O> {
|
||||||
protected String[] orderedInputNodes;
|
protected String[] orderedInputNodes;
|
||||||
protected String[] orderedOutputNodes;
|
protected String[] orderedOutputNodes;
|
||||||
|
|
||||||
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port) {
|
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port) {
|
||||||
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
|
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
|
||||||
|
Preconditions.checkArgument(serializer == null && binarySerializer == null ||
|
||||||
|
serializer != null && binarySerializer == null ||
|
||||||
|
serializer == null && binarySerializer != null,
|
||||||
|
"JSON and binary serializers/deserializers are mutually exclusive and mandatory.");
|
||||||
|
|
||||||
|
this.binarySerializer = binarySerializer;
|
||||||
|
this.binaryDeserializer = binaryDeserializer;
|
||||||
this.inferenceAdapter = inferenceAdapter;
|
this.inferenceAdapter = inferenceAdapter;
|
||||||
this.serializer = serializer;
|
this.serializer = serializer;
|
||||||
this.deserializer = deserializer;
|
this.deserializer = deserializer;
|
||||||
|
@ -74,8 +87,11 @@ public class SameDiffJsonModelServer<I, O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
//@Builder
|
//@Builder
|
||||||
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
|
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
this(inferenceAdapter, serializer, deserializer, port);
|
JsonSerializer<O> serializer, JsonDeserializer<I> deserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer,
|
||||||
|
int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
|
||||||
|
this(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
|
||||||
this.sdModel = sdModel;
|
this.sdModel = sdModel;
|
||||||
this.orderedInputNodes = orderedInputNodes;
|
this.orderedInputNodes = orderedInputNodes;
|
||||||
this.orderedOutputNodes = orderedOutputNodes;
|
this.orderedOutputNodes = orderedOutputNodes;
|
||||||
|
@ -282,7 +298,7 @@ public class SameDiffJsonModelServer<I, O> {
|
||||||
} else
|
} else
|
||||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||||
}
|
}
|
||||||
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, null, null, port, orderedInputNodes, orderedOutputNodes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||||||
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||||
import org.nd4j.adapters.InferenceAdapter;
|
import org.nd4j.adapters.InferenceAdapter;
|
||||||
|
@ -35,6 +37,7 @@ import java.io.InputStreamReader;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
|
||||||
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
||||||
|
import static javax.ws.rs.core.MediaType.APPLICATION_OCTET_STREAM;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This servlet provides SameDiff model serving capabilities
|
* This servlet provides SameDiff model serving capabilities
|
||||||
|
@ -50,9 +53,14 @@ import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
||||||
@Builder
|
@Builder
|
||||||
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||||
|
|
||||||
|
protected static final String typeJson = APPLICATION_JSON;
|
||||||
|
protected static final String typeBinary = APPLICATION_OCTET_STREAM;
|
||||||
|
|
||||||
protected SameDiff sdModel;
|
protected SameDiff sdModel;
|
||||||
protected JsonSerializer<O> serializer;
|
protected JsonSerializer<O> serializer;
|
||||||
protected JsonDeserializer<I> deserializer;
|
protected JsonDeserializer<I> deserializer;
|
||||||
|
protected BinarySerializer<O> binarySerializer;
|
||||||
|
protected BinaryDeserializer<I> binaryDeserializer;
|
||||||
protected InferenceAdapter<I, O> inferenceAdapter;
|
protected InferenceAdapter<I, O> inferenceAdapter;
|
||||||
|
|
||||||
protected String[] orderedInputNodes;
|
protected String[] orderedInputNodes;
|
||||||
|
@ -60,14 +68,36 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||||
|
|
||||||
protected final static String SERVING_ENDPOINT = "/v1/serving";
|
protected final static String SERVING_ENDPOINT = "/v1/serving";
|
||||||
protected final static String LISTING_ENDPOINT = "/v1";
|
protected final static String LISTING_ENDPOINT = "/v1";
|
||||||
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable
|
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024 * 1024; // TODO: should be customizable
|
||||||
|
|
||||||
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer){
|
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer){
|
||||||
this.serializer = serializer;
|
this.serializer = serializer;
|
||||||
this.deserializer = deserializer;
|
this.deserializer = deserializer;
|
||||||
this.inferenceAdapter = inferenceAdapter;
|
this.inferenceAdapter = inferenceAdapter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
BinarySerializer<O> serializer, BinaryDeserializer<I> deserializer){
|
||||||
|
this.binarySerializer = serializer;
|
||||||
|
this.binaryDeserializer = deserializer;
|
||||||
|
this.inferenceAdapter = inferenceAdapter;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||||
|
JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer,
|
||||||
|
BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer){
|
||||||
|
|
||||||
|
this.serializer = jsonSerializer;
|
||||||
|
this.deserializer = jsonDeserializer;
|
||||||
|
this.binarySerializer = binarySerializer;
|
||||||
|
this.binaryDeserializer = binaryDeserializer;
|
||||||
|
this.inferenceAdapter = inferenceAdapter;
|
||||||
|
|
||||||
|
if (serializer != null && binarySerializer != null || serializer == null && binarySerializer == null)
|
||||||
|
throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void init(ServletConfig servletConfig) throws ServletException {
|
public void init(ServletConfig servletConfig) throws ServletException {
|
||||||
//
|
//
|
||||||
|
@ -108,7 +138,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||||
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
|
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
|
||||||
throws IOException{
|
throws IOException{
|
||||||
val contentType = request.getContentType();
|
val contentType = request.getContentType();
|
||||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
if (!StringUtils.equals(contentType, typeJson)) {
|
||||||
sendBadContentType(contentType, response);
|
sendBadContentType(contentType, response);
|
||||||
int contentLength = request.getContentLength();
|
int contentLength = request.getContentLength();
|
||||||
if (contentLength > PAYLOAD_SIZE_LIMIT) {
|
if (contentLength > PAYLOAD_SIZE_LIMIT) {
|
||||||
|
@ -125,7 +155,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||||
String path = request.getPathInfo();
|
String path = request.getPathInfo();
|
||||||
if (path.equals(LISTING_ENDPOINT)) {
|
if (path.equals(LISTING_ENDPOINT)) {
|
||||||
val contentType = request.getContentType();
|
val contentType = request.getContentType();
|
||||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
if (!StringUtils.equals(contentType, typeJson)) {
|
||||||
sendBadContentType(contentType, response);
|
sendBadContentType(contentType, response);
|
||||||
}
|
}
|
||||||
processorReturned = processor.listEndpoints();
|
processorReturned = processor.listEndpoints();
|
||||||
|
@ -147,7 +177,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||||
String path = request.getPathInfo();
|
String path = request.getPathInfo();
|
||||||
if (path.equals(SERVING_ENDPOINT)) {
|
if (path.equals(SERVING_ENDPOINT)) {
|
||||||
val contentType = request.getContentType();
|
val contentType = request.getContentType();
|
||||||
/*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON),
|
/*Preconditions.checkArgument(StringUtils.equals(contentType, typeJson),
|
||||||
"Content type is " + contentType);*/
|
"Content type is " + contentType);*/
|
||||||
if (validateRequest(request,response)) {
|
if (validateRequest(request,response)) {
|
||||||
val stream = request.getInputStream();
|
val stream = request.getInputStream();
|
||||||
|
|
Loading…
Reference in New Issue