diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java index cf7d11838..e84488f9d 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform.client; + import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.exceptions.UnirestException; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 1e525fc3b..2cc01e288 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -51,11 +51,13 @@ datavec-spark-inference-model ${datavec.version} + org.datavec datavec-spark_2.11 ${project.version} + org.datavec datavec-data-image @@ -67,61 +69,73 @@ akka-cluster_2.11 ${akka.version} + joda-time joda-time ${jodatime.version} + org.apache.commons commons-lang3 ${commons-lang3.version} + org.hibernate hibernate-validator ${hibernate.version} + org.scala-lang scala-library ${scala.version} + org.scala-lang scala-reflect ${scala.version} + org.yaml snakeyaml ${snakeyaml.version} + com.fasterxml.jackson.core jackson-core ${jackson.version} + com.fasterxml.jackson.core jackson-databind ${jackson.version} + com.fasterxml.jackson.core jackson-annotations ${jackson.version} + com.fasterxml.jackson.datatype jackson-datatype-jdk8 ${jackson.version} + com.fasterxml.jackson.datatype jackson-datatype-jsr310 ${jackson.version} + com.typesafe.play play-java_2.11 @@ -137,39 +151,44 @@ + net.jodah typetools ${jodah.typetools.version} + com.typesafe.play play-json_2.11 ${play.version} + com.typesafe.play play-server_2.11 ${play.version} + com.typesafe.play play_2.11 ${play.version} + com.typesafe.play play-netty-server_2.11 ${play.version} - com.mashape.unirest unirest-java ${unirest.version} test + com.beust jcommander diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java index d3d4aee30..e524f00a7 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java @@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest { public static void before() throws Exception { server = new CSVSparkTransformServer(); FileUtils.write(fileSave, transformProcess.toJson()); + // Only one time Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = @@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest { } } }); + server.runMain(new String[] {"-dp", "9050"}); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java index 78021d3e5..a9e24d78e 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; @@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest { server = new CSVSparkTransformServer(); FileUtils.write(fileSave, transformProcess.toJson()); // Only one time + Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = new org.nd4j.shade.jackson.databind.ObjectMapper(); @@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest { } } }); + server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java index 62dec47c1..bfae23358 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java index a0561563b..e08d881ef 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java index eaf967325..4c14f4c3d 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java @@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; import com.mashape.unirest.request.HttpRequest; -import com.mashape.unirest.request.HttpRequestWithBody; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; +import lombok.val; import org.deeplearning4j.nearestneighbor.model.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; @@ -51,6 +51,7 @@ public class NearestNeighborsClient { static { // Only one time + Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = new org.nd4j.shade.jackson.databind.ObjectMapper(); @@ -89,7 +90,7 @@ public class NearestNeighborsClient { NearestNeighborRequest request = new NearestNeighborRequest(); request.setInputIndex(index); request.setK(k); - HttpRequestWithBody req = Unirest.post(url + "/knn"); + val req = Unirest.post(url + "/knn"); req.header("accept", "application/json") .header("Content-Type", "application/json").body(request); addAuthHeader(req); @@ -112,7 +113,7 @@ public class NearestNeighborsClient { Base64NDArrayBody base64NDArrayBody = Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); - HttpRequestWithBody req = Unirest.post(url + "/knnnew"); + val req = Unirest.post(url + "/knnnew"); req.header("accept", "application/json") .header("Content-Type", "application/json").body(base64NDArrayBody); addAuthHeader(req); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java index 9fcd76981..5313edd22 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java @@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonParser; -import jdk.nashorn.internal.objects.annotations.Property; import lombok.Getter; import lombok.NonNull; import org.apache.commons.compress.compressors.gzip.GzipUtils; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java index b1e53e90a..74ffd3fe8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.adapters; import lombok.val; -import org.deeplearning4j.nn.api.OutputAdapter; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java index b67767050..f0c83669b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.deeplearning4j.nn.api.OutputAdapter; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java index ae7e4b5c1..6b99a92d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.api; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.api.ndarray.INDArray; /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index cf3fec70c..99f8aeff0 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -24,6 +24,7 @@ import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.exception.DL4JException; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 719a727f2..731ca398b 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -25,6 +25,7 @@ import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.dataset.AsyncDataSetIterator;; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.eval.RegressionEvaluation; diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml new file mode 100644 index 000000000..e2d5f901f --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml @@ -0,0 +1,110 @@ + + + + 4.0.0 + jar + + + org.deeplearning4j + deeplearning4j-remote + 1.0.0-SNAPSHOT + + + deeplearning4j-json-server + 1.0.0-SNAPSHOT + deeplearning4j-json-server + + + + junit + junit + ${junit.version} + test + + + + org.projectlombok + lombok + ${lombok.version} + provided + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.nd4j + nd4j-json-client + ${project.version} + + + + org.nd4j + nd4j-json-server + ${project.version} + + + + org.deeplearning4j + deeplearning4j-parallel-wrapper + ${project.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + + + + test-nd4j-native + + true + + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + test-nd4j-cuda-10.1 + + false + + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java new file mode 100644 index 000000000..796a51ad0 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java @@ -0,0 +1,205 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.parallelism.ParallelInference; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.remote.serving.SameDiffServlet; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; + +/** + * + * @author astoyakin + */ +@Slf4j +@NoArgsConstructor +public class DL4jServlet extends SameDiffServlet { + + protected ParallelInference parallelInference; + protected Model model; + protected boolean parallelEnabled = true; + + public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter inferenceAdapter, + @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer) { + super(inferenceAdapter, serializer, deserializer); + this.parallelInference = parallelInference; + this.model = null; + this.parallelEnabled = true; + } + + public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter inferenceAdapter, + @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer) { + super(inferenceAdapter, serializer, deserializer); + this.model = model; + this.parallelInference = null; + this.parallelEnabled = false; + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(SERVING_ENDPOINT)) { + val contentType = request.getContentType(); + if (validateRequest(request,response)) { + val stream = request.getInputStream(); + val bufferedReader = new BufferedReader(new InputStreamReader(stream)); + char[] charBuffer = new char[128]; + int bytesRead = -1; + val buffer = new StringBuilder(); + while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { + buffer.append(charBuffer, 0, bytesRead); + } + val requestString = buffer.toString(); + + val mds = inferenceAdapter.apply(deserializer.deserialize(requestString)); + + O result = null; + if (parallelEnabled) { + // process result + result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays())); + } + else { + synchronized(this) { + if (model instanceof ComputationGraph) + result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays())); + else if (model instanceof MultiLayerNetwork) { + Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1), + "Input data for MultilayerNetwork is invalid!"); + result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false, + mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null)); + } + } + } + processorReturned = serializer.serialize(result); + } + } else { + // we return error otherwise + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + /** + * Creates servlet to serve models + * + * @param type of Input class + * @param type of Output class + * + * @author raver119@gmail.com + * @author astoyakin + */ + public static class Builder { + + private ParallelInference pi; + private Model model; + + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + private int port; + private boolean parallelEnabled = true; + + public Builder(@NonNull ParallelInference pi) { + this.pi = pi; + } + + public Builder(@NonNull Model model) { + this.model = model; + } + + public Builder inferenceAdapter(@NonNull InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method is required to specify serializer + * + * @param serializer + * @return + */ + public Builder serializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method allows to specify deserializer + * + * @param deserializer + * @return + */ + public Builder deserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method allows to specify port + * + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method activates parallel inference + * + * @param parallelEnabled + * @return + */ + public Builder parallelEnabled(boolean parallelEnabled) { + this.parallelEnabled = parallelEnabled; + return this; + } + + public DL4jServlet build() { + return parallelEnabled ? new DL4jServlet(pi, inferenceAdapter, serializer, deserializer) : + new DL4jServlet(model, inferenceAdapter, serializer, deserializer); + } + } +} + + + + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java new file mode 100644 index 000000000..030231f6e --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java @@ -0,0 +1,392 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.NonNull; +import lombok.val; +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.api.ModelAdapter; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.parallelism.ParallelInference; +import org.deeplearning4j.parallelism.inference.InferenceMode; +import org.deeplearning4j.parallelism.inference.LoadBalanceMode; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.adapters.InputAdapter; +import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.remote.SameDiffJsonModelServer; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + + +import java.util.List; + +/** + * This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models + * + * Server url will be http://0.0.0.0:{port}>/v1/serving + * Server only accepts POST requests + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + * @author astoyakin + */ +public class JsonModelServer extends SameDiffJsonModelServer { + + // all serving goes through ParallelInference + protected ParallelInference parallelInference; + + + protected ModelAdapter modelAdapter; + + // actual models + protected ComputationGraph cgModel; + protected MultiLayerNetwork mlnModel; + + // service stuff + protected InferenceMode inferenceMode; + protected int numWorkers; + + protected boolean enabledParallel = true; + + protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) { + super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } + + protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { + super(inferenceAdapter, serializer, deserializer, port); + + this.cgModel = cgModel; + this.inferenceMode = inferenceMode; + this.numWorkers = numWorkers; + } + + protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { + super(inferenceAdapter, serializer, deserializer, port); + + this.mlnModel = mlnModel; + this.inferenceMode = inferenceMode; + this.numWorkers = numWorkers; + } + + protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port) { + super(inferenceAdapter, serializer, deserializer, port); + + this.parallelInference = pi; + } + + /** + * This method stops server + * + * @throws Exception + */ + @Override + public void stop() throws Exception { + if (parallelInference != null) + parallelInference.shutdown(); + super.stop(); + } + + /** + * This method starts server + * @throws Exception + */ + @Override + public void start() throws Exception { + // if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case + if (sdModel != null) { + super.start(); + return; + } + Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined"); + + val model = cgModel != null ? (Model) cgModel : (Model) mlnModel; + // PI construction is optional, since we can have it defined + if (enabledParallel) { + if (parallelInference == null) { + Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead"); + + parallelInference = new ParallelInference.Builder(model) + .inferenceMode(inferenceMode) + .workers(numWorkers) + .loadBalanceMode(LoadBalanceMode.FIFO) + .batchLimit(16) + .queueLimit(128) + .build(); + } + servingServlet = new DL4jServlet.Builder(parallelInference) + .parallelEnabled(true) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .build(); + } + else { + servingServlet = new DL4jServlet.Builder(model) + .parallelEnabled(false) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .build(); + } + start(port, servingServlet); + } + + /** + * Creates servlet to serve different types of models + * + * @param type of Input class + * @param type of Output class + * + * @author raver119@gmail.com + * @author astoyakin + */ + public static class Builder { + + private SameDiff sdModel; + private ComputationGraph cgModel; + private MultiLayerNetwork mlnModel; + private ParallelInference pi; + + private String[] orderedInputNodes; + private String[] orderedOutputNodes; + + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + + private InputAdapter inputAdapter; + private OutputAdapter outputAdapter; + + private int port; + + private boolean parallelMode = true; + + // these fields actually require defaults + private InferenceMode inferenceMode = InferenceMode.BATCHED; + private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices(); + + public Builder(@NonNull SameDiff sdModel) { + this.sdModel = sdModel; + } + + public Builder(@NonNull MultiLayerNetwork mlnModel) { + this.mlnModel = mlnModel; + } + + public Builder(@NonNull ComputationGraph cgModel) { + this.cgModel = cgModel; + } + + public Builder(@NonNull ParallelInference pi) { + this.pi = pi; + } + + /** + * This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type + * @param inferenceAdapter + * @return + */ + public Builder inferenceAdapter(@NonNull InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method allows you to specify InputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require OutputAdapter defined + * @param inputAdapter + * @return + */ + public Builder inputAdapter(@NonNull InputAdapter inputAdapter) { + this.inputAdapter = inputAdapter; + return this; + } + + /** + * This method allows you to specify OutputtAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require InputAdapter defined + * @param outputAdapter + * @return + */ + public Builder outputAdapter(@NonNull OutputAdapter outputAdapter) { + this.outputAdapter = outputAdapter; + return this; + } + + /** + * This method allows you to specify serializer + * + * @param serializer + * @return + */ + public Builder outputSerializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method allows you to specify deserializer + * + * @param deserializer + * @return + */ + public Builder inputDeserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details + * + * @param inferenceMode + * @return + */ + public Builder inferenceMode(@NonNull InferenceMode inferenceMode) { + this.inferenceMode = inferenceMode; + return this; + } + + /** + * This method allows you to specify number of worker threads for ParallelInference + * + * @param numWorkers + * @return + */ + public Builder numWorkers(int numWorkers) { + this.numWorkers = numWorkers; + return this; + } + + /** + * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedInputNodes(String... args) { + orderedInputNodes = args; + return this; + } + + /** + * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedInputNodes(@NonNull List args) { + orderedInputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows you to specify output nodes + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedOutputNodes(String... args) { + Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args; + return this; + } + + /** + * This method allows you to specify output nodes + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedOutputNodes(@NonNull List args) { + Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows you to specify http port + * + * PLEASE NOTE: port must be free and be in range regular TCP/IP ports range + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method switches on ParallelInference usage + * @param - true - to use ParallelInference, false - to use ComputationGraph or + * MultiLayerNetwork directly + * + * PLEASE NOTE: this doesn't apply to SameDiff models + * + * @throws Exception + */ + public Builder parallelMode(boolean enable) { + this.parallelMode = enable; + return this; + } + + public JsonModelServer build() { + if (inferenceAdapter == null) { + if (inputAdapter != null && outputAdapter != null) { + inferenceAdapter = new InferenceAdapter() { + @Override + public MultiDataSet apply(I input) { + return inputAdapter.apply(input); + } + + @Override + public O apply(INDArray... outputs) { + return outputAdapter.apply(outputs); + } + }; + } else + throw new IllegalArgumentException("Either InferenceAdapter or InputAdapter + OutputAdapter should be configured"); + } + + if (sdModel != null) { + Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined"); + return new JsonModelServer(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } else if (cgModel != null) + return new JsonModelServer(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); + else if (mlnModel != null) + return new JsonModelServer(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); + else if (pi != null) + return new JsonModelServer(pi, inferenceAdapter, serializer, deserializer, port); + else + throw new IllegalStateException("No models were defined for JsonModelServer"); + } + } + +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java new file mode 100644 index 000000000..aa353f307 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -0,0 +1,749 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.graph.MergeVertex; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.parallelism.inference.InferenceMode; +import org.deeplearning4j.remote.helpers.House; +import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter; +import org.deeplearning4j.remote.helpers.PredictedPrice; +import org.junit.After; +import org.junit.Test; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.remote.clients.JsonRemoteInference; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.databind.ObjectMapper; + + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE; +import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; +import static org.junit.Assert.*; + +@Slf4j +public class JsonModelServerTest { + private static final MultiLayerNetwork model; + private final int PORT = 18080; + + static { + val conf = new NeuralNetConfiguration.Builder() + .seed(119) + .updater(new Adam(0.119f)) + .weightInit(WeightInit.XAVIER) + .list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build()) + .build(); + + model = new MultiLayerNetwork(conf); + model.init(); + } + + @After + public void pause() throws Exception { + // TODO: the same port was used in previous test and not accessible immediately. Might be better solution. + TimeUnit.SECONDS.sleep(2); + } + + + @Test + public void testStartStopParallel() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val serverDL = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + val serverSD = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT+1) + .build(); + try { + serverDL.start(); + serverSD.start(); + + val clientDL = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + PredictedPrice price = clientDL.predict(house); + long timeStart = System.currentTimeMillis(); + price = clientDL.predict(house); + long timeStop = System.currentTimeMillis(); + log.info("Time spent: {} ms", timeStop - timeStart); + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + val clientSD = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving") + .build(); + + PredictedPrice price2 = clientSD.predict(house); + timeStart = System.currentTimeMillis(); + price = clientSD.predict(house); + timeStop = System.currentTimeMillis(); + log.info("Time spent: {} ms", timeStop - timeStart); + assertNotNull(price); + assertEquals((float) 3.0, price.getPrice(), 1e-5); + + } + finally { + serverSD.stop(); + serverDL.stop(); + } + } + + @Test + public void testStartStopSequential() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val serverDL = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + val serverSD = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT+1) + .build(); + + serverDL.start(); + serverDL.stop(); + + serverSD.start(); + serverSD.stop(); + } + + @Test + public void basicServingTestForSD() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) district + 1.0f, price.getPrice(), 1e-5); + } + finally { + server.stop(); + } + } + + @Test + public void basicServingTestForDLSynchronized() throws Exception { + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(INPLACE) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build(); + House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house1); + + val timeStart = System.currentTimeMillis(); + PredictedPrice price1 = client.predict(house1); + PredictedPrice price2 = client.predict(house2); + PredictedPrice price3 = client.predict(house3); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + } finally { + server.stop(); + } + } + + @Test + public void basicServingTestForDL() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .parallelMode(false) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + } finally { + server.stop(); + } + } + + @Test + public void testDeserialization_1() { + String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}"; + val deserializer = new House.HouseDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(2, result.getDistrict()); + assertEquals(100, result.getArea()); + assertEquals(2, result.getBathrooms()); + assertEquals(3, result.getBedrooms()); + + } + + @Test + public void testDeserialization_2() { + String request = "{\"price\":1}"; + val deserializer = new PredictedPrice.PredictedPriceDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(1.0, result.getPrice(), 1e-4); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_1() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(null) + .port(18080) + .build(); + } + + @Test //(expected = NullPointerException.class) + public void negativeServingTest_2() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + } + + @Test(expected = IOException.class) + public void negativeServingTest_3() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(SEQUENTIAL) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + } finally { + server.stop(); + } + } + + @Test + public void asyncServingTest() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(SEQUENTIAL) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) 0.421444, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } + finally { + server.stop(); + } + } + + @Test + public void negativeAsyncTest() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(InferenceMode.BATCHED) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + // Fake deserializer to test failure + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + try { + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } catch (ExecutionException e) { + assertTrue(e.getMessage().contains("Deserialization failed")); + } + } finally { + server.stop(); + } + } + + + @Test + public void testSameDiffMnist() throws Exception { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10)); + SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); + + val server = new JsonModelServer.Builder(sd) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT+1) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try{ + server.start(); + for( int i=0; i<10; i++ ){ + INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28); + INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax"); + float[] fArr = f.toFloatVector(); + int out = client.predict(fArr); + assertEquals(exp.argMax().getInt(0), out); + } + } finally { + server.stop(); + } + } + + @Test + public void testMlnMnist() throws Exception { + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new DenseLayer.Builder().nIn(784).nOut(10).build()) + .layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + for (int i = 0; i < 10; i++) { + INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28); + INDArray exp = net.output(f); + float[] fArr = f.toFloatVector(); + int out = client.predict(fArr); + assertEquals(exp.argMax().getInt(0), out); + } + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + @Test + public void testCompGraph() throws Exception { + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .graphBuilder() + .addInputs("input1", "input2") + .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1") + .addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2") + .addVertex("merge", new MergeVertex(), "L1", "L2") + .addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge") + .setOutputs("out") + .build(); + + ComputationGraph net = new ComputationGraph(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .parallelMode(false) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + //client.predict(new float[]{0.0f, 1.0f, 2.0f}); + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + @Test + public void testCompGraph_1() throws Exception { + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .updater(new Sgd(0.01)) + .graphBuilder() + .addInputs("input") + .addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input") + .addLayer("out1", new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nIn(4).nOut(3).build(), "L1") + .addLayer("out2", new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.MSE) + .nIn(4).nOut(2).build(), "L1") + .setOutputs("out1","out2") + .build(); + + final ComputationGraph net = new ComputationGraph(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("input") + .orderedOutputNodes("out") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .parallelMode(false) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + assertNotNull(result); + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + private static class FloatSerde implements JsonSerializer, JsonDeserializer{ + private final ObjectMapper om = new ObjectMapper(); + + @Override + public float[] deserialize(String json) { + try { + return om.readValue(json, FloatHolder.class).getFloats(); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + @Override + public String serialize(float[] o) { + try{ + return om.writeValueAsString(new FloatHolder(o)); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + //Use float holder so Jackson does ser/de properly (no "{}" otherwise) + @AllArgsConstructor @NoArgsConstructor @Data + private static class FloatHolder { + private float[] floats; + } + } + + private static class IntSerde implements JsonSerializer, JsonDeserializer { + private final ObjectMapper om = new ObjectMapper(); + + @Override + public Integer deserialize(String json) { + try { + return om.readValue(json, Integer.class); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + @Override + public String serialize(Integer o) { + try{ + return om.writeValueAsString(o); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + } +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java new file mode 100644 index 000000000..1b347d112 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.val; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class ServletTest { + + private JsonModelServer server; + + @Before + public void setUp() throws Exception { + val sd = SameDiff.create(); + server = new JsonModelServer.Builder(sd) + .port(8080) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(String input) { + return null; + } + + @Override + public String apply(INDArray... nnOutput) { + return null; + } + }) + .outputSerializer(new JsonSerializer() { + @Override + public String serialize(String o) { + return ""; + } + }) + .inputDeserializer(new JsonDeserializer() { + @Override + public String deserialize(String json) { + return ""; + } + }) + .orderedInputNodes("input") + .orderedOutputNodes("output") + .build(); + + server.start(); + //server.join(); + } + + @After + public void tearDown() throws Exception { + server.stop(); + } + + @Test + public void getEndpoints() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "application/json"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypeGet() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "text/plain"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypePost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "text/plain"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void postForServing() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(500, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundPost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving/some"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundGet() throws Exception { + val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" ); + requestGet.setHeader("Content-type", "application/json"); + + val responseGet = HttpClientBuilder.create().build().execute( requestGet ); + assertEquals(404, responseGet.getStatusLine().getStatusCode()); + } + +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java new file mode 100644 index 000000000..d66c8bae5 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class House { + private int district; + private int bedrooms; + private int bathrooms; + private int area; + + + public static class HouseSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull House o) { + return new Gson().toJson(o); + } + } + + public static class HouseDeserializer implements JsonDeserializer { + @Override + public House deserialize(@NonNull String json) { + return new Gson().fromJson(json, House.class); + } + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java new file mode 100644 index 000000000..82976a3da --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; + +@Slf4j +public class HouseToPredictedPriceAdapter implements InferenceAdapter { + + @Override + public MultiDataSet apply(@NonNull House input) { + // we just create vector array with shape[4] and assign it's value to the district value + return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null); + } + + @Override + public PredictedPrice apply(INDArray... nnOutput) { + return new PredictedPrice(nnOutput[0].getFloat(0)); + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java new file mode 100644 index 000000000..c4024bb1d --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class PredictedPrice { + private float price; + + public static class PredictedPriceSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull PredictedPrice o) { + return new Gson().toJson(o); + } + } + + public static class PredictedPriceDeserializer implements JsonDeserializer { + @Override + public PredictedPrice deserialize(@NonNull String json) { + return new Gson().fromJson(json, PredictedPrice.class); + } + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml new file mode 100644 index 000000000..59b35644e --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml @@ -0,0 +1,48 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml new file mode 100644 index 000000000..1fc937e10 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/pom.xml @@ -0,0 +1,30 @@ + + + + 4.0.0 + pom + + + deeplearning4j-json-server + + + + org.deeplearning4j + deeplearning4j + 1.0.0-SNAPSHOT + + + deeplearning4j-remote + 1.0.0-SNAPSHOT + deeplearning4j-remote + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 48966a57d..149a122f2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -21,7 +21,6 @@ import lombok.*; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.ModelAdapter; -import org.deeplearning4j.nn.api.OutputAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 022b28076..0e2fd339a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -22,7 +22,6 @@ import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.ModelAdapter; -import org.deeplearning4j.nn.api.OutputAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index e13b62cda..b8b70befb 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -144,6 +144,7 @@ dl4j-perf dl4j-integration-tests deeplearning4j-common + deeplearning4j-remote diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java new file mode 100644 index 000000000..671ef613d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.adapters; + +/** + * This interface describes methods needed to convert custom JVM objects to INDArrays, suitable for feeding neural networks + * + * @param type of the Input for the model. I.e. String for raw text + * @param type of the Output for the model, I.e. Sentiment, for Text->Sentiment extraction + * + * @author raver119@gmail.com + */ +public interface InferenceAdapter extends InputAdapter, OutputAdapter { +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java new file mode 100644 index 000000000..09ca4eaa6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.adapters; + +import org.nd4j.linalg.dataset.MultiDataSet; + +/** + * This interface describes method for transformation from object of type I to MultiDataSet. + * + */ +public interface InputAdapter { + /** + * This method converts input object to MultiDataSet + * @param input + * @return + */ + MultiDataSet apply(I input); +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java similarity index 88% rename from deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java index 4b32e83c2..ba1ff40d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java @@ -14,15 +14,15 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.nn.api; +package org.nd4j.adapters; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; /** - * This interface describes entity used to conver neural network output to specified class. - * I.e. INDArray -> int[] on the fly. + * This interface describes entity used to convert neural network output to specified class. + * I.e. INDArray -> int[] or INDArray -> Sentiment on the fly * * PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference. * This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly. diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index bc33b9e55..5b26a81ea 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public native void printBuffer(); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * print element by element consequently in a way they (elements) are stored in physical memory + */ + public native void printLinearBuffer(); + /** * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status */ @@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432; * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432; * @param originalTadNum the tad number for the reduced version of the problem */ -/** - * Returns the prod of the data - * up to the given length - */ /** * Returns the prod of the data diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 5e2a4e296..8c0bebfc1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public native void printBuffer(); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * print element by element consequently in a way they (elements) are stored in physical memory + */ + public native void printLinearBuffer(); + /** * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status */ @@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432; * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432; * @param originalTadNum the tad number for the reduced version of the problem */ -/** - * Returns the prod of the data - * up to the given length - */ /** * Returns the prod of the data diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml index 7d69be7b6..734b1b738 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml @@ -50,6 +50,8 @@ httpmime ${httpmime.version} + + com.mashape.unirest unirest-java diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index c895a9cc6..1122f90d7 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -54,11 +54,13 @@ httpmime ${httpmime.version} + com.mashape.unirest unirest-java ${unirest.version} + org.nd4j nd4j-jackson diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index dd698b96c..f69b9c24b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -21,13 +21,15 @@ import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import com.beust.jcommander.Parameters; import com.google.common.primitives.Ints; -import com.mashape.unirest.http.HttpResponse; + +import org.nd4j.shade.jackson.databind.ObjectMapper; import com.mashape.unirest.http.Unirest; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import io.aeron.driver.ThreadingMode; import lombok.Data; import lombok.NoArgsConstructor; +import lombok.val; import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; import org.json.JSONObject; @@ -49,7 +51,6 @@ import org.nd4j.parameterserver.updater.SoftSyncParameterUpdater; import org.nd4j.parameterserver.updater.SynchronousParameterUpdater; import org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage; import org.nd4j.parameterserver.util.CheckSocket; -import org.nd4j.shade.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -342,7 +343,7 @@ public class ParameterServerSubscriber implements AutoCloseable { JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState)); String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId); - HttpResponse entity = Unirest.post(url).header("Content-Type", "application/json") + val entity = Unirest.post(url).header("Content-Type", "application/json") .body(jsonObject).asString(); } catch (Exception e) { failCount.incrementAndGet(); diff --git a/nd4j/nd4j-remote/README.md b/nd4j/nd4j-remote/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/nd4j/nd4j-serde/nd4j-grpc/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml similarity index 97% rename from nd4j/nd4j-serde/nd4j-grpc/pom.xml rename to nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 9105e333b..1f668eb66 100644 --- a/nd4j/nd4j-serde/nd4j-grpc/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -19,13 +19,13 @@ - nd4j-serde + nd4j-remote org.nd4j 1.0.0-SNAPSHOT 4.0.0 - nd4j-grpc + nd4j-grpc-client nd4j-grpc diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/pom.xml b/nd4j/nd4j-remote/nd4j-json-client/pom.xml new file mode 100644 index 000000000..d1ceeeda9 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/pom.xml @@ -0,0 +1,70 @@ + + + + 4.0.0 + jar + + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + + + nd4j-json-client + + nd4j-json-client + + + UTF-8 + 1.7 + 1.7 + + + + + junit + junit + test + + + + com.mashape.unirest + unirest-java + ${unirest.version} + + + + org.slf4j + slf4j-api + + + + org.nd4j + jackson + ${project.version} + + + + + + testresources + + + + org.nd4j + nd4j-native + ${project.version} + test + + + org.nd4j + nd4j-native + ${project.version} + ${javacpp.platform} + test + + + + + diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java new file mode 100644 index 000000000..bbd6ac49e --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java @@ -0,0 +1,153 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients; + +import com.mashape.unirest.http.HttpResponse; +import com.mashape.unirest.http.Unirest; +import com.mashape.unirest.http.exceptions.UnirestException; +import lombok.Builder; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.json.JSONObject; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * This class provides remote inference functionality via JSON-powered REST APIs. + * + * Basically we assume that there's remote JSON server available (on bare metal or in k8s/swarm/whatever cluster), and with proper serializers/deserializers provided we can issue REST requests and get back responses. + * So, in this way application logic can be separated from DL logic. + * + * You just need to provide serializer/deserializer and address of the REST server, i.e. "http://model:8080/v1/serving" + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + */ +@Slf4j +public class JsonRemoteInference { + private String endpointAddress; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + + @Builder + public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer inputSerializer, @NonNull JsonDeserializer outputDeserializer) { + this.endpointAddress = endpointAddress; + this.serializer = inputSerializer; + this.deserializer = outputDeserializer; + } + + private O processResponse(HttpResponse response) throws IOException { + if (response.getStatus() != 200) + throw new IOException("Inference request returned bad error code: " + response.getStatus()); + + O result = deserializer.deserialize(response.getBody()); + if (result == null) { + throw new IOException("Deserialization failed!"); + } + return result; + } + + /** + * This method does remote inference in a blocking way + * + * @param input + * @return + * @throws IOException + */ + public O predict(I input) throws IOException { + try { + val stringResult = Unirest.post(endpointAddress) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(new JSONObject(serializer.serialize(input))).asString(); + + return processResponse(stringResult); + } catch (UnirestException e) { + throw new IOException(e); + } + } + + /** + * This method does remote inference in asynchronous way, returning Future instead + * @param input + * @return + */ + public Future predictAsync(I input) { + val stringResult = Unirest.post(endpointAddress) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(new JSONObject(serializer.serialize(input))).asStringAsync(); + return new InferenceFuture(stringResult); + } + + /** + * This class holds a Future of the object returned by remote inference server + */ + private class InferenceFuture implements Future { + private Future> unirestFuture; + + private InferenceFuture(@NonNull Future> future) { + this.unirestFuture = future; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return unirestFuture.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return unirestFuture.isCancelled(); + } + + @Override + public boolean isDone() { + return unirestFuture.isDone(); + } + + @Override + public O get() throws InterruptedException, ExecutionException { + val stringResult = unirestFuture.get(); + + try { + return processResponse(stringResult); + } catch (IOException e) { + throw new ExecutionException(e); + } + } + + @Override + public O get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + val stringResult = unirestFuture.get(timeout, unit); + + try { + return processResponse(stringResult); + } catch (IOException e) { + throw new ExecutionException(e); + } + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java new file mode 100644 index 000000000..77f8939b2 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde; + +/** + * This interface describes basic JSON deserializer interface used for JsonRemoteInference + * @param type of the deserializable class + * + * @author raver119@gmail.com + */ +public interface JsonDeserializer { + + /** + * This method serializes given object into JSON-string + * @param json string containing JSON representation of the object + * @return + */ + T deserialize(String json); +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java new file mode 100644 index 000000000..4258b98cc --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde; + +/** + * This interface describes basic JSON serializer interface used for JsonRemoteInference + * @param type of the serializable class + * + * @author raver119@gmail.com + */ +public interface JsonSerializer { + + /** + * This method serializes given object into JSON-string + * + * @param o object to be serialized + * @return + */ + String serialize(T o); +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java new file mode 100644 index 000000000..fa02496bf --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +public abstract class AbstractSerDe implements JsonDeserializer, JsonSerializer { + protected ObjectMapper objectMapper = new ObjectMapper(); + + + protected String serializeClass(@NonNull T obj) { + try { + return objectMapper.writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + protected T deserializeClass(@NonNull String json, @NonNull Class cls) { + try { + return objectMapper.readValue(json, cls); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java new file mode 100644 index 000000000..db92a4346 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Boolean. Single value only. + */ +public class BooleanSerde extends AbstractSerDe { + @Override + public Boolean deserialize(@NonNull String json) { + return deserializeClass(json, Boolean.class); + } + + @Override + public String serialize(@NonNull Boolean o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java new file mode 100644 index 000000000..fbe44a101 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +/** + * This class provides JSON ser/de for Java double[] + */ +public class DoubleArraySerde extends AbstractSerDe { + + @Override + public String serialize(@NonNull double[] data) { + return serializeClass(data); + } + + @Override + public double[] deserialize(@NonNull String json) { + return deserializeClass(json, double[].class); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java new file mode 100644 index 000000000..d12a1dc66 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Double. Single value only. + */ +public class DoubleSerde extends AbstractSerDe { + @Override + public Double deserialize(@NonNull String json) { + return deserializeClass(json, Double.class); + } + + @Override + public String serialize(@NonNull Double o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java new file mode 100644 index 000000000..24c78ecfd --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.remote.clients.serde.impl; + + +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + + +/** + * This class provides JSON ser/de for Java float[] + */ +public class FloatArraySerde extends AbstractSerDe { + + @Override + public String serialize(@NonNull float[] data) { + return serializeClass(data); + } + + @Override + public float[] deserialize(@NonNull String json) { + return deserializeClass(json, float[].class); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java new file mode 100644 index 000000000..14b822ba7 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Float. Single value only. + */ +public class FloatSerde extends AbstractSerDe { + @Override + public Float deserialize(@NonNull String json) { + return deserializeClass(json, Float.class); + } + + @Override + public String serialize(@NonNull Float o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java new file mode 100644 index 000000000..0d8b40272 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Integer. Single value only. + */ +public class IntegerSerde extends AbstractSerDe { + @Override + public Integer deserialize(@NonNull String json) { + return deserializeClass(json, Integer.class); + } + + @Override + public String serialize(@NonNull Integer o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java new file mode 100644 index 000000000..20e1d8fd4 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +/** + * This class provides fake JSON serializer/deserializer functionality for String. + * It doesn't put any JSON-specific bits into actual string + */ +public class StringSerde extends AbstractSerDe { + + @Override + public String serialize(@NonNull String data) { + return serializeClass(data); + } + + @Override + public String deserialize(@NonNull String json) { + return deserializeClass(json, String.class); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/README.md b/nd4j/nd4j-remote/nd4j-json-server/README.md new file mode 100644 index 000000000..963890a7b --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/README.md @@ -0,0 +1,35 @@ +## SameDiff model serving + +This modules provides JSON-based serving of SameDiff models + +## Example + +First of all we'll create server instance. Most probably you'll do it in application that will be running in container +```java +val server = SameDiffJsonModelServer.builder() + .adapter(new StringToSentimentAdapter()) + .model(mySameDiffModel) + .port(8080) + .serializer(new SentimentSerializer()) + .deserializer(new StringDeserializer()) + .build(); + +server.start(); +server.join(); +``` + +Now, presumably in some other container, we'll set up remote inference client: +```java +val client = JsonRemoteInference.builder() + .endpointAddress("http://youraddress:8080/v1/serving") + .serializer(new StringSerializer()) + .deserializer(new SentimentDeserializer()) + .build(); + +Sentiment result = client.predict(myText); +``` + On top of that, there's async call available, for cases when you need to chain multiple requests to one or multiple remote model servers. + +```java +Future result = client.predictAsync(myText); +``` \ No newline at end of file diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml new file mode 100644 index 000000000..d5e506839 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -0,0 +1,179 @@ + + + + 4.0.0 + jar + + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + + + nd4j-json-server + nd4j-json-server + + + UTF-8 + 1.7 + 1.7 + 2.26 + + + + + junit + junit + test + + + + org.nd4j + nd4j-json-client + ${project.version} + + + + org.slf4j + slf4j-api + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + + org.eclipse.jetty + jetty-server + 9.4.19.v20190610 + + + + org.eclipse.jetty + jetty-servlet + 9.4.19.v20190610 + + + + org.glassfish.jersey.inject + jersey-hk2 + ${jersey.version} + + + + org.glassfish.jersey.media + jersey-media-json-processing + ${jersey.version} + + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + activation + 1.1 + + + + + + + nd4j-tests-cpu + + true + + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + + + testresources + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.source} + ${maven.compiler.target} + + + + + diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java new file mode 100644 index 000000000..93f640b75 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java @@ -0,0 +1,288 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.nd4j.adapters.InputAdapter; +import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.remote.serving.ModelServingServlet; +import org.nd4j.remote.serving.SameDiffServlet; + +import java.util.List; + +/** + * This class provides JSON-powered model serving functionality for SameDiff graphs. + * Server url will be http://0.0.0.0:{port}>/v1/serving + * Server only accepts POST requests + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + */ +@Slf4j +public class SameDiffJsonModelServer { + + + protected SameDiff sdModel; + protected final JsonSerializer serializer; + protected final JsonDeserializer deserializer; + protected final InferenceAdapter inferenceAdapter; + protected final int port; + + // this servlet will be used to serve models + protected ModelServingServlet servingServlet; + + // HTTP server instance + protected Server server; + + // for SameDiff only + protected String[] orderedInputNodes; + protected String[] orderedOutputNodes; + + protected SameDiffJsonModelServer(@NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer, int port) { + Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535"); + + this.inferenceAdapter = inferenceAdapter; + this.serializer = serializer; + this.deserializer = deserializer; + this.port = port; + } + + //@Builder + public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) { + this(inferenceAdapter, serializer, deserializer, port); + this.sdModel = sdModel; + this.orderedInputNodes = orderedInputNodes; + this.orderedOutputNodes = orderedOutputNodes; + + // TODO: both lists of nodes should be validated, to make sure nodes specified here exist in actual model + if (orderedInputNodes != null) { + // input nodes list might be null. strange but ok + } + + Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "SameDiff serving requires at least 1 output node"); + } + + protected void start(int port, @NonNull ModelServingServlet servlet) throws Exception { + val context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + + server = new Server(port); + server.setHandler(context); + + val jerseyServlet = context.addServlet(org.glassfish.jersey.servlet.ServletContainer.class, "/*"); + jerseyServlet.setInitOrder(0); + jerseyServlet.setServlet(servlet); + + server.start(); + } + + public void start() throws Exception { + Preconditions.checkArgument(sdModel != null, "SameDiff model wasn't defined"); + + servingServlet = SameDiffServlet.builder() + .sdModel(sdModel) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .orderedInputNodes(orderedInputNodes) + .orderedOutputNodes(orderedOutputNodes) + .build(); + + start(port, servingServlet); + } + + public void join() throws InterruptedException { + Preconditions.checkArgument(server != null, "Model server wasn't started yet"); + + server.join(); + } + + public void stop() throws Exception { + //Preconditions.checkArgument(server != null, "Model server wasn't started yet"); + + server.stop(); + } + + + public static class Builder { + private SameDiff sameDiff; + private String[] orderedInputNodes; + private String[] orderedOutputNodes; + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + private int port; + + private InputAdapter inputAdapter; + private OutputAdapter outputAdapter; + + public Builder() {} + + public Builder sdModel(@NonNull SameDiff sameDiff) { + this.sameDiff = sameDiff; + return this; + } + + /** + * This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type + * @param inferenceAdapter + * @return + */ + public Builder inferenceAdapter(InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method allows you to specify InputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require OutputAdapter defined + * @param inputAdapter + * @return + */ + public Builder inputAdapter(@NonNull InputAdapter inputAdapter) { + this.inputAdapter = inputAdapter; + return this; + } + + /** + * This method allows you to specify OutputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require InputAdapter defined + * @param outputAdapter + * @return + */ + public Builder outputAdapter(@NonNull OutputAdapter outputAdapter) { + this.outputAdapter = outputAdapter; + return this; + } + + /** + * This method defines JsonSerializer instance to be used to convert object of output type into JSON format, so it could be sent over the wire + * + * @param serializer + * @return + */ + public Builder outputSerializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method defines JsonDeserializer instance to be used to convert JSON passed through HTTP into actual object of input type, that will be fed into SameDiff model + * + * @param deserializer + * @return + */ + public Builder inputDeserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method defines the order of placeholders to be filled with INDArrays provided by Deserializer + * + * @param args + * @return + */ + public Builder orderedInputNodes(String... args) { + orderedInputNodes = args; + return this; + } + + /** + * This method defines the order of placeholders to be filled with INDArrays provided by Deserializer + * + * @param args + * @return + */ + public Builder orderedInputNodes(@NonNull List args) { + orderedInputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input + * @param args + * @return + */ + public Builder orderedOutputNodes(String... args) { + Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args; + return this; + } + + /** + * This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input + * @param args + * @return + */ + public Builder orderedOutputNodes(@NonNull List args) { + Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows to configure HTTP port used for serving + * + * PLEASE NOTE: port must be free and be in range regular TCP/IP ports range + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method builds SameDiffJsonModelServer instance + * @return + */ + public SameDiffJsonModelServer build() { + if (inferenceAdapter == null) { + if (inputAdapter != null && outputAdapter != null) { + inferenceAdapter = new InferenceAdapter() { + @Override + public MultiDataSet apply(I input) { + return inputAdapter.apply(input); + } + + @Override + public O apply(INDArray... outputs) { + return outputAdapter.apply(outputs); + } + }; + } else + throw new IllegalArgumentException("Either InferenceAdapter or InputAdapter + OutputAdapter should be configured"); + } + return new SameDiffJsonModelServer(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java new file mode 100644 index 000000000..dc240b4f6 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.serving; + +import javax.servlet.Servlet; + +/** + * This interface describes Servlet interface extension, suited for ND4J/DL4J model serving + * @param + * @param + * + * @author raver119@gmail.com + */ +public interface ModelServingServlet extends Servlet { + // +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java new file mode 100644 index 000000000..0be7b8757 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java @@ -0,0 +1,206 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.serving; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; + +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.HttpMethod; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.LinkedHashMap; + +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; + +/** + * This servlet provides SameDiff model serving capabilities + * + * @param + * @param + * + * @author raver119@gmail.com + */ +@NoArgsConstructor +@AllArgsConstructor +@Slf4j +@Builder +public class SameDiffServlet implements ModelServingServlet { + + protected SameDiff sdModel; + protected JsonSerializer serializer; + protected JsonDeserializer deserializer; + protected InferenceAdapter inferenceAdapter; + + protected String[] orderedInputNodes; + protected String[] orderedOutputNodes; + + protected final static String SERVING_ENDPOINT = "/v1/serving"; + protected final static String LISTING_ENDPOINT = "/v1"; + protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable + + protected SameDiffServlet(@NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer){ + this.serializer = serializer; + this.deserializer = deserializer; + this.inferenceAdapter = inferenceAdapter; + } + + @Override + public void init(ServletConfig servletConfig) throws ServletException { + // + } + + @Override + public ServletConfig getServletConfig() { + return null; + } + + @Override + public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException { + // we'll parse request here, and do model serving + val httpRequest = (HttpServletRequest) servletRequest; + val httpResponse = (HttpServletResponse) servletResponse; + + if (httpRequest.getMethod().equals(HttpMethod.GET)) { + doGet(httpRequest, httpResponse); + } + else if (httpRequest.getMethod().equals(HttpMethod.POST)) { + doPost(httpRequest, httpResponse); + } + + } + + protected void sendError(String uri, HttpServletResponse response) throws IOException { + val msg = "Requested endpoint [" + uri + "] not found"; + response.setStatus(404, msg); + response.sendError(404, msg); + } + + protected void sendBadContentType(String actualContentType, HttpServletResponse response) throws IOException { + val msg = "Content type [" + actualContentType + "] not supported"; + response.setStatus(415, msg); + response.sendError(415, msg); + } + + protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response) + throws IOException{ + val contentType = request.getContentType(); + if (!StringUtils.equals(contentType, APPLICATION_JSON)) { + sendBadContentType(contentType, response); + int contentLength = request.getContentLength(); + if (contentLength > PAYLOAD_SIZE_LIMIT) { + response.sendError(500, "Payload size limit violated!"); + } + return false; + } + return true; + } + + protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + val processor = new ServingProcessor(); + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(LISTING_ENDPOINT)) { + val contentType = request.getContentType(); + if (!StringUtils.equals(contentType, APPLICATION_JSON)) { + sendBadContentType(contentType, response); + } + processorReturned = processor.listEndpoints(); + } + else { + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { + val processor = new ServingProcessor(); + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(SERVING_ENDPOINT)) { + val contentType = request.getContentType(); + /*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON), + "Content type is " + contentType);*/ + if (validateRequest(request,response)) { + val stream = request.getInputStream(); + val bufferedReader = new BufferedReader(new InputStreamReader(stream)); + char[] charBuffer = new char[128]; + int bytesRead = -1; + val buffer = new StringBuilder(); + while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { + buffer.append(charBuffer, 0, bytesRead); + } + val requestString = buffer.toString(); + + val mds = inferenceAdapter.apply(deserializer.deserialize(requestString)); + val map = new LinkedHashMap(); + + // optionally define placeholders with names provided in server constructor + if (orderedInputNodes != null && orderedInputNodes.length > 0) { + int cnt = 0; + for (val n : orderedInputNodes) + map.put(n, mds.getFeatures(cnt++)); + } + + val output = sdModel.exec(map, orderedOutputNodes); + val arrays = new INDArray[output.size()]; + + // now we need to get ordered output arrays, as specified in server constructor + int cnt = 0; + for (val n : orderedOutputNodes) + arrays[cnt++] = output.get(n); + + // process result + val result = inferenceAdapter.apply(arrays); + processorReturned = serializer.serialize(result); + } + } else { + // we return error otherwise + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + @Override + public String getServletInfo() { + return null; + } + + @Override + public void destroy() { + // + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java new file mode 100644 index 000000000..6f77215bd --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java @@ -0,0 +1,14 @@ +package org.nd4j.remote.serving; + +public class ServingProcessor { + + public String listEndpoints() { + String retVal = "/v1/ \n/v1/serving/"; + return retVal; + } + + public String processModel(String body) { + String response = null; //"Not implemented"; + return response; + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java new file mode 100644 index 000000000..290d1c4a1 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java @@ -0,0 +1,271 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.remote.clients.JsonRemoteInference; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.helpers.House; +import org.nd4j.remote.helpers.HouseToPredictedPriceAdapter; +import org.nd4j.remote.helpers.PredictedPrice; +import org.nd4j.remote.clients.serde.impl.FloatArraySerde; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import static org.junit.Assert.*; + +@Slf4j +public class SameDiffJsonModelServerTest { + + @Test + public void basicServingTest_1() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) district + 1.0f, price.getPrice(), 1e-5); + + server.stop(); + } + + @Test + public void testDeserialization_1() { + String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}"; + val deserializer = new House.HouseDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(2, result.getDistrict()); + assertEquals(100, result.getArea()); + assertEquals(2, result.getBathrooms()); + assertEquals(3, result.getBedrooms()); + + } + + @Test + public void testDeserialization_2() { + String request = "{\"price\":1}"; + val deserializer = new PredictedPrice.PredictedPriceDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(1.0, result.getPrice(), 1e-4); + } + + @Test + public void testDeserialization_3() { + float[] data = {0.0f, 0.1f, 0.2f}; + val serialized = new FloatArraySerde().serialize(data); + val deserialized = new FloatArraySerde().deserialize(serialized); + assertArrayEquals(data, deserialized, 1e-5f); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_1() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(null) + .sdModel(sd) + .port(18080) + .build(); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_2() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .sdModel(sd) + .port(18080) + .build(); + + } + + @Test(expected = IOException.class) + public void negativeServingTest_3() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + server.stop(); + } + + @Test + public void asyncServingTest() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + + server.stop(); + } + + @Test + public void negativeAsyncTest() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + // Fake deserializer to test failure + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + try { + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } catch (ExecutionException e) { + assertTrue(e.getMessage().contains("Deserialization failed")); + } + + server.stop(); + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java new file mode 100644 index 000000000..a26809efc --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java @@ -0,0 +1,116 @@ +package org.nd4j.remote; + +import lombok.val; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class SameDiffServletTest { + + private SameDiffJsonModelServer server; + + @Before + public void setUp() throws Exception { + server = new SameDiffJsonModelServer.Builder() + .sdModel(SameDiff.create()) + .port(8080) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(String input) { + return null; + } + + @Override + public String apply(INDArray... nnOutput) { + return null; + } + }) + .outputSerializer(new JsonSerializer() { + @Override + public String serialize(String o) { + return ""; + } + }) + .inputDeserializer(new JsonDeserializer() { + @Override + public String deserialize(String json) { + return ""; + } + }) + .orderedOutputNodes(new String[]{"output"}) + .build(); + + server.start(); + //server.join(); + } + + @After + public void tearDown() throws Exception { + server.stop(); + } + + @Test + public void getEndpoints() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "application/json"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypeGet() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "text/plain"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypePost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "text/plain"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void postForServing() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(500, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundPost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving/some"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundGet() throws Exception { + val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" ); + requestGet.setHeader("Content-type", "application/json"); + + val responseGet = HttpClientBuilder.create().build().execute( requestGet ); + assertEquals(404, responseGet.getStatusLine().getStatusCode()); + } + +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java new file mode 100644 index 000000000..e5089e585 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class House { + private int district; + private int bedrooms; + private int bathrooms; + private int area; + + + public static class HouseSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull House o) { + return new Gson().toJson(o); + } + } + + public static class HouseDeserializer implements JsonDeserializer { + @Override + public House deserialize(@NonNull String json) { + return new Gson().fromJson(json, House.class); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java new file mode 100644 index 000000000..fe07623d3 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.helpers; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.adapters.InferenceAdapter; + +@Slf4j +public class HouseToPredictedPriceAdapter implements InferenceAdapter { + + @Override + public MultiDataSet apply(@NonNull House input) { + // we just create vector array with shape[4] and assign it's value to the district value + return new MultiDataSet(Nd4j.create(DataType.FLOAT, 4).assign(input.getDistrict()), null); + } + + @Override + public PredictedPrice apply(INDArray... nnOutput) { + return new PredictedPrice(nnOutput[0].getFloat(0)); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java new file mode 100644 index 000000000..41d2bd253 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class PredictedPrice { + private float price; + + public static class PredictedPriceSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull PredictedPrice o) { + return new Gson().toJson(o); + } + } + + public static class PredictedPriceDeserializer implements JsonDeserializer { + @Override + public PredictedPrice deserialize(@NonNull String json) { + return new Gson().fromJson(json, PredictedPrice.class); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java new file mode 100644 index 000000000..3aca94b8a --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java @@ -0,0 +1,90 @@ +package org.nd4j.remote.serde; + +import lombok.val; +import org.junit.Test; +import org.nd4j.remote.clients.serde.impl.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class BasicSerdeTests { + private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde(); + private final static FloatArraySerde floatArraySerde = new FloatArraySerde(); + private final static StringSerde stringSerde = new StringSerde(); + private final static IntegerSerde integerSerde = new IntegerSerde(); + private final static FloatSerde floatSerde = new FloatSerde(); + private final static DoubleSerde doubleSerde = new DoubleSerde(); + private final static BooleanSerde booleanSerde = new BooleanSerde(); + + @Test + public void testStringSerde_1() { + val jvmString = "String with { strange } elements"; + + val serialized = stringSerde.serialize(jvmString); + val deserialized = stringSerde.deserialize(serialized); + + assertEquals(jvmString, deserialized); + } + + @Test + public void testFloatArraySerDe_1() { + val jvmArray = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + val serialized = floatArraySerde.serialize(jvmArray); + val deserialized = floatArraySerde.deserialize(serialized); + + assertArrayEquals(jvmArray, deserialized, 1e-5f); + } + + @Test + public void testDoubleArraySerDe_1() { + val jvmArray = new double[] {1.0, 2.0, 3.0, 4.0, 5.0}; + + val serialized = doubleArraySerde.serialize(jvmArray); + val deserialized = doubleArraySerde.deserialize(serialized); + + assertArrayEquals(jvmArray, deserialized, 1e-5); + } + + @Test + public void testFloatSerde_1() { + val f = 119.f; + + val serialized = floatSerde.serialize(f); + val deserialized = floatSerde.deserialize(serialized); + + assertEquals(f, deserialized, 1e-5f); + } + + @Test + public void testDoubleSerde_1() { + val d = 119.; + + val serialized = doubleSerde.serialize(d); + val deserialized = doubleSerde.deserialize(serialized); + + assertEquals(d, deserialized, 1e-5); + } + + @Test + public void testIntegerSerde_1() { + val f = 119; + + val serialized = integerSerde.serialize(f); + val deserialized = integerSerde.deserialize(serialized); + + + assertEquals(f, deserialized.intValue()); + } + + @Test + public void testBooleanSerde_1() { + val f = true; + + val serialized = booleanSerde.serialize(f); + val deserialized = booleanSerde.deserialize(serialized); + + + assertEquals(f, deserialized); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml new file mode 100644 index 000000000..59b35644e --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml @@ -0,0 +1,48 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/nd4j/nd4j-remote/pom.xml b/nd4j/nd4j-remote/pom.xml new file mode 100644 index 000000000..c6a0cdb86 --- /dev/null +++ b/nd4j/nd4j-remote/pom.xml @@ -0,0 +1,35 @@ + + + + 4.0.0 + pom + + + nd4j-json-client + nd4j-grpc-client + nd4j-json-server + + + + org.nd4j + nd4j + 1.0.0-SNAPSHOT + + + nd4j-remote + 1.0.0-SNAPSHOT + nd4j-remote + + + UTF-8 + 1.7 + 1.7 + + + + + testresources + + + diff --git a/nd4j/nd4j-serde/pom.xml b/nd4j/nd4j-serde/pom.xml index 3234912cb..d4fc4ff05 100644 --- a/nd4j/nd4j-serde/pom.xml +++ b/nd4j/nd4j-serde/pom.xml @@ -33,7 +33,6 @@ nd4j-camel-routes nd4j-gson nd4j-arrow - nd4j-grpc diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 255859171..6c294d7e7 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -62,6 +62,7 @@ nd4j-parameter-server-parent nd4j-uberjar nd4j-tensorflow + nd4j-remote