diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java
index cf7d11838..e84488f9d 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java
@@ -16,6 +16,7 @@
package org.datavec.spark.transform.client;
+
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
index 1e525fc3b..2cc01e288 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
@@ -51,11 +51,13 @@
             <artifactId>datavec-spark-inference-model</artifactId>
             <version>${datavec.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.datavec</groupId>
             <artifactId>datavec-spark_2.11</artifactId>
             <version>${project.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.datavec</groupId>
             <artifactId>datavec-data-image</artifactId>
@@ -67,61 +69,73 @@
             <artifactId>akka-cluster_2.11</artifactId>
             <version>${akka.version}</version>
         </dependency>
+
         <dependency>
             <groupId>joda-time</groupId>
             <artifactId>joda-time</artifactId>
             <version>${jodatime.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-lang3</artifactId>
             <version>${commons-lang3.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.hibernate</groupId>
             <artifactId>hibernate-validator</artifactId>
             <version>${hibernate.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
             <version>${scala.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-reflect</artifactId>
             <version>${scala.version}</version>
         </dependency>
+
         <dependency>
             <groupId>org.yaml</groupId>
             <artifactId>snakeyaml</artifactId>
             <version>${snakeyaml.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-core</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-databind</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-annotations</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.datatype</groupId>
             <artifactId>jackson-datatype-jdk8</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.datatype</groupId>
             <artifactId>jackson-datatype-jsr310</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-java_2.11</artifactId>
@@ -137,39 +151,44 @@
         </dependency>
+
         <dependency>
             <groupId>net.jodah</groupId>
             <artifactId>typetools</artifactId>
             <version>${jodah.typetools.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-json_2.11</artifactId>
             <version>${play.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-server_2.11</artifactId>
             <version>${play.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play_2.11</artifactId>
             <version>${play.version}</version>
         </dependency>
+
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-netty-server_2.11</artifactId>
             <version>${play.version}</version>
         </dependency>
-
         <dependency>
             <groupId>com.mashape.unirest</groupId>
             <artifactId>unirest-java</artifactId>
             <version>${unirest.version}</version>
             <scope>test</scope>
         </dependency>
+
         <dependency>
             <groupId>com.beust</groupId>
             <artifactId>jcommander</artifactId>
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java
index d3d4aee30..e524f00a7 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java
@@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest {
public static void before() throws Exception {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
+
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
@@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest {
}
}
});
+
server.runMain(new String[] {"-dp", "9050"});
}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java
index 78021d3e5..a9e24d78e 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java
@@ -16,6 +16,7 @@
package org.datavec.spark.transform;
+
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
@@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
// Only one time
+
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest {
}
}
});
+
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java
index 62dec47c1..bfae23358 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java
@@ -16,6 +16,7 @@
package org.datavec.spark.transform;
+
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java
index a0561563b..e08d881ef 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java
@@ -16,6 +16,7 @@
package org.datavec.spark.transform;
+
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java
index eaf967325..4c14f4c3d 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java
@@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
-import com.mashape.unirest.request.HttpRequestWithBody;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
+import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
@@ -51,6 +51,7 @@ public class NearestNeighborsClient {
static {
// Only one time
+
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@@ -89,7 +90,7 @@ public class NearestNeighborsClient {
NearestNeighborRequest request = new NearestNeighborRequest();
request.setInputIndex(index);
request.setK(k);
- HttpRequestWithBody req = Unirest.post(url + "/knn");
+ val req = Unirest.post(url + "/knn");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(request);
addAuthHeader(req);
@@ -112,7 +113,7 @@ public class NearestNeighborsClient {
Base64NDArrayBody base64NDArrayBody =
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
- HttpRequestWithBody req = Unirest.post(url + "/knnnew");
+ val req = Unirest.post(url + "/knnnew");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(base64NDArrayBody);
addAuthHeader(req);
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java
index 9fcd76981..5313edd22 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java
@@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
-import jdk.nashorn.internal.objects.annotations.Property;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java
index b1e53e90a..74ffd3fe8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java
@@ -17,7 +17,7 @@
package org.deeplearning4j.nn.adapters;
import lombok.val;
-import org.deeplearning4j.nn.api.OutputAdapter;
+import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java
index b67767050..f0c83669b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java
@@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
-import org.deeplearning4j.nn.api.OutputAdapter;
+import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java
index ae7e4b5c1..6b99a92d4 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.api;
+import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index cf3fec70c..99f8aeff0 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -24,6 +24,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
+import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.exception.DL4JException;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
index 719a727f2..731ca398b 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
@@ -25,6 +25,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
+import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml
new file mode 100644
index 000000000..e2d5f901f
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml
@@ -0,0 +1,110 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <packaging>jar</packaging>
+
+    <parent>
+        <groupId>org.deeplearning4j</groupId>
+        <artifactId>deeplearning4j-remote</artifactId>
+        <version>1.0.0-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>deeplearning4j-json-server</artifactId>
+    <version>1.0.0-SNAPSHOT</version>
+    <name>deeplearning4j-json-server</name>
+
+    <dependencies>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <version>${junit.version}</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <version>${lombok.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-api</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-json-client</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-json-server</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>${slf4j.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>ch.qos.logback</groupId>
+            <artifactId>logback-core</artifactId>
+            <version>${logback.version}</version>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>ch.qos.logback</groupId>
+            <artifactId>logback-classic</artifactId>
+            <version>${logback.version}</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+            <activation>
+                <activeByDefault>true</activeByDefault>
+            </activation>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-native</artifactId>
+                    <version>${project.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+
+        <profile>
+            <id>test-nd4j-cuda-10.1</id>
+            <activation>
+                <activeByDefault>false</activeByDefault>
+            </activation>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <version>${project.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+    </profiles>
+</project>
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java
new file mode 100644
index 000000000..796a51ad0
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java
@@ -0,0 +1,205 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote;
+
+import lombok.*;
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.api.Model;
+import org.deeplearning4j.nn.api.NeuralNetwork;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.parallelism.ParallelInference;
+import org.nd4j.adapters.InferenceAdapter;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+import org.nd4j.remote.serving.SameDiffServlet;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+
+/**
+ * A servlet serving DL4J models (MultiLayerNetwork and ComputationGraph),
+ * either directly or via ParallelInference
+ *
+ * @param <I> type of the Input class
+ * @param <O> type of the Output class
+ *
+ * @author astoyakin
+ */
+@Slf4j
+@NoArgsConstructor
+public class DL4jServlet<I, O> extends SameDiffServlet<I, O> {
+
+ protected ParallelInference parallelInference;
+ protected Model model;
+ protected boolean parallelEnabled = true;
+
+    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
+                       @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
+ super(inferenceAdapter, serializer, deserializer);
+ this.parallelInference = parallelInference;
+ this.model = null;
+ this.parallelEnabled = true;
+ }
+
+    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
+                       @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
+ super(inferenceAdapter, serializer, deserializer);
+ this.model = model;
+ this.parallelInference = null;
+ this.parallelEnabled = false;
+ }
+
+ @Override
+ protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
+ String processorReturned = "";
+ String path = request.getPathInfo();
+ if (path.equals(SERVING_ENDPOINT)) {
+ val contentType = request.getContentType();
+ if (validateRequest(request,response)) {
+ val stream = request.getInputStream();
+ val bufferedReader = new BufferedReader(new InputStreamReader(stream));
+ char[] charBuffer = new char[128];
+ int bytesRead = -1;
+ val buffer = new StringBuilder();
+ while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
+ buffer.append(charBuffer, 0, bytesRead);
+ }
+ val requestString = buffer.toString();
+
+ val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
+
+ O result = null;
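+                // Two serving paths: with ParallelInference the request joins the shared
+                // batching/load-balancing queue; otherwise the model is invoked directly,
+                // synchronized because a single Model instance is shared across servlet threads.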
+ if (parallelEnabled) {
+ // process result
+ result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
+ }
+ else {
+ synchronized(this) {
+ if (model instanceof ComputationGraph)
+ result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
+ else if (model instanceof MultiLayerNetwork) {
+ Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
+ "Input data for MultilayerNetwork is invalid!");
+ result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
+ mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
+ }
+ }
+ }
+ processorReturned = serializer.serialize(result);
+ }
+ } else {
+ // we return error otherwise
+ sendError(request.getRequestURI(), response);
+ }
+ try {
+ val out = response.getWriter();
+ out.write(processorReturned);
+ } catch (IOException e) {
+ log.error(e.getMessage());
+ }
+ }
+
+    /**
+     * Builder for servlets serving models
+     *
+     * @param <I> type of Input class
+     * @param <O> type of Output class
+     *
+     * @author raver119@gmail.com
+     * @author astoyakin
+     */
+    public static class Builder<I, O> {
+
+ private ParallelInference pi;
+ private Model model;
+
+        private InferenceAdapter<I, O> inferenceAdapter;
+        private JsonSerializer<O> serializer;
+        private JsonDeserializer<I> deserializer;
+ private int port;
+ private boolean parallelEnabled = true;
+
+ public Builder(@NonNull ParallelInference pi) {
+ this.pi = pi;
+ }
+
+ public Builder(@NonNull Model model) {
+ this.model = model;
+ }
+
+        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
+ this.inferenceAdapter = inferenceAdapter;
+ return this;
+ }
+
+        /**
+         * This method is required to specify the serializer
+         *
+         * @param serializer
+         * @return
+         */
+        public Builder<I, O> serializer(@NonNull JsonSerializer<O> serializer) {
+            this.serializer = serializer;
+            return this;
+        }
+
+        /**
+         * This method allows you to specify the deserializer
+         *
+         * @param deserializer
+         * @return
+         */
+        public Builder<I, O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
+            this.deserializer = deserializer;
+            return this;
+        }
+
+        /**
+         * This method allows you to specify the port
+         *
+         * @param port
+         * @return
+         */
+        public Builder<I, O> port(int port) {
+            this.port = port;
+            return this;
+        }
+
+        /**
+         * This method enables or disables parallel inference
+         *
+         * @param parallelEnabled
+         * @return
+         */
+        public Builder<I, O> parallelEnabled(boolean parallelEnabled) {
+            this.parallelEnabled = parallelEnabled;
+            return this;
+        }
+
+        public DL4jServlet<I, O> build() {
+            return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
+                            new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
+ }
+ }
+}
+
+
+
+
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java
new file mode 100644
index 000000000..030231f6e
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java
@@ -0,0 +1,392 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote;
+
+import lombok.NonNull;
+import lombok.val;
+import org.deeplearning4j.nn.api.Model;
+import org.deeplearning4j.nn.api.ModelAdapter;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.parallelism.ParallelInference;
+import org.deeplearning4j.parallelism.inference.InferenceMode;
+import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
+import org.nd4j.adapters.InferenceAdapter;
+import org.nd4j.adapters.InputAdapter;
+import org.nd4j.adapters.OutputAdapter;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.remote.SameDiffJsonModelServer;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+
+
+import java.util.List;
+
+/**
+ * This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models
+ *
+ * Server url will be http://0.0.0.0:{port}/v1/serving
+ * Server only accepts POST requests
+ *
+ * @param <I> type of the input class, i.e. String
+ * @param <O> type of the output class, i.e. Sentiment
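+ *
+ * A minimal usage sketch, mirroring the builder chain used in this PR's tests
+ * (House, PredictedPrice and the adapter come from those tests; net is assumed
+ * to be an initialized MultiLayerNetwork, and the port value is arbitrary):
+ * <pre>{@code
+ * JsonModelServer<House, PredictedPrice> server = new JsonModelServer.Builder<House, PredictedPrice>(net)
+ *     .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ *     .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ *     .inputDeserializer(new House.HouseDeserializer())
+ *     .port(18080)
+ *     .build();
+ * server.start();
+ * }</pre>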
+ *
+ * @author raver119@gmail.com
+ * @author astoyakin
+ */
+public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
+
+ // all serving goes through ParallelInference
+ protected ParallelInference parallelInference;
+
+
+    protected ModelAdapter<O> modelAdapter;
+
+ // actual models
+ protected ComputationGraph cgModel;
+ protected MultiLayerNetwork mlnModel;
+
+ // service stuff
+ protected InferenceMode inferenceMode;
+ protected int numWorkers;
+
+ protected boolean enabledParallel = true;
+
+    protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
+ super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
+ }
+
+    protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
+ super(inferenceAdapter, serializer, deserializer, port);
+
+ this.cgModel = cgModel;
+ this.inferenceMode = inferenceMode;
+ this.numWorkers = numWorkers;
+ }
+
+    protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
+ super(inferenceAdapter, serializer, deserializer, port);
+
+ this.mlnModel = mlnModel;
+ this.inferenceMode = inferenceMode;
+ this.numWorkers = numWorkers;
+ }
+
+    protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
+ super(inferenceAdapter, serializer, deserializer, port);
+
+ this.parallelInference = pi;
+ }
+
+ /**
+     * This method stops the server
+ *
+ * @throws Exception
+ */
+ @Override
+ public void stop() throws Exception {
+ if (parallelInference != null)
+ parallelInference.shutdown();
+ super.stop();
+ }
+
+ /**
+     * This method starts the server
+ * @throws Exception
+ */
+ @Override
+ public void start() throws Exception {
+ // if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case
+ if (sdModel != null) {
+ super.start();
+ return;
+ }
+        Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultiLayerNetwork or ComputationGraph defined");
+
+ val model = cgModel != null ? (Model) cgModel : (Model) mlnModel;
+ // PI construction is optional, since we can have it defined
+ if (enabledParallel) {
+ if (parallelInference == null) {
+ Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");
+
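+                // Defaults chosen here: FIFO load balancing, at most 16 requests merged
+                // into a single batch, and a queue capped at 128 pending requests.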
+ parallelInference = new ParallelInference.Builder(model)
+ .inferenceMode(inferenceMode)
+ .workers(numWorkers)
+ .loadBalanceMode(LoadBalanceMode.FIFO)
+ .batchLimit(16)
+ .queueLimit(128)
+ .build();
+ }
+            servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
+ .parallelEnabled(true)
+ .serializer(serializer)
+ .deserializer(deserializer)
+ .inferenceAdapter(inferenceAdapter)
+ .build();
+ }
+ else {
+            servingServlet = new DL4jServlet.Builder<I, O>(model)
+ .parallelEnabled(false)
+ .serializer(serializer)
+ .deserializer(deserializer)
+ .inferenceAdapter(inferenceAdapter)
+ .build();
+ }
+ start(port, servingServlet);
+ }
+
+    /**
+     * Builder for servers serving different types of models
+     *
+     * @param <I> type of Input class
+     * @param <O> type of Output class
+     *
+     * @author raver119@gmail.com
+     * @author astoyakin
+     */
+    public static class Builder<I, O> {
+
+ private SameDiff sdModel;
+ private ComputationGraph cgModel;
+ private MultiLayerNetwork mlnModel;
+ private ParallelInference pi;
+
+ private String[] orderedInputNodes;
+ private String[] orderedOutputNodes;
+
+        private InferenceAdapter<I, O> inferenceAdapter;
+        private JsonSerializer<O> serializer;
+        private JsonDeserializer<I> deserializer;
+
+        private InputAdapter<I> inputAdapter;
+        private OutputAdapter<O> outputAdapter;
+
+ private int port;
+
+ private boolean parallelMode = true;
+
+ // these fields actually require defaults
+ private InferenceMode inferenceMode = InferenceMode.BATCHED;
+ private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();
+
+ public Builder(@NonNull SameDiff sdModel) {
+ this.sdModel = sdModel;
+ }
+
+ public Builder(@NonNull MultiLayerNetwork mlnModel) {
+ this.mlnModel = mlnModel;
+ }
+
+ public Builder(@NonNull ComputationGraph cgModel) {
+ this.cgModel = cgModel;
+ }
+
+ public Builder(@NonNull ParallelInference pi) {
+ this.pi = pi;
+ }
+
+        /**
+         * This method defines the InferenceAdapter implementation, which is used to convert an object of Input type
+         * into a set of INDArray(s), and to convert the resulting INDArray(s) back into an object of Output type
+         * @param inferenceAdapter
+         * @return
+         */
+        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
+ this.inferenceAdapter = inferenceAdapter;
+ return this;
+ }
+
+        /**
+         * This method allows you to specify the InputAdapter to be used for inference
+         *
+         * PLEASE NOTE: This method is optional, and requires an OutputAdapter to be defined as well
+         * @param inputAdapter
+         * @return
+         */
+        public Builder<I, O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
+ this.inputAdapter = inputAdapter;
+ return this;
+ }
+
+        /**
+         * This method allows you to specify the OutputAdapter to be used for inference
+         *
+         * PLEASE NOTE: This method is optional, and requires an InputAdapter to be defined as well
+         * @param outputAdapter
+         * @return
+         */
+        public Builder<I, O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
+ this.outputAdapter = outputAdapter;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify serializer
+ *
+ * @param serializer
+ * @return
+ */
+        public Builder<I, O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
+ this.serializer = serializer;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify deserializer
+ *
+ * @param deserializer
+ * @return
+ */
+        public Builder<I, O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
+ this.deserializer = deserializer;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
+ *
+ * @param inferenceMode
+ * @return
+ */
+        public Builder<I, O> inferenceMode(@NonNull InferenceMode inferenceMode) {
+ this.inferenceMode = inferenceMode;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify number of worker threads for ParallelInference
+ *
+ * @param numWorkers
+ * @return
+ */
+        public Builder<I, O> numWorkers(int numWorkers) {
+ this.numWorkers = numWorkers;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
+ *
+         * PLEASE NOTE: this argument is only used for SameDiff models
+ * @param args
+ * @return
+ */
+        public Builder<I, O> orderedInputNodes(String... args) {
+ orderedInputNodes = args;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
+ *
+         * PLEASE NOTE: this argument is only used for SameDiff models
+ * @param args
+ * @return
+ */
+        public Builder<I, O> orderedInputNodes(@NonNull List<String> args) {
+ orderedInputNodes = args.toArray(new String[args.size()]);
+ return this;
+ }
+
+ /**
+ * This method allows you to specify output nodes
+ *
+         * PLEASE NOTE: this argument is only used for SameDiff models
+ * @param args
+ * @return
+ */
+        public Builder<I, O> orderedOutputNodes(String... args) {
+ Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
+ orderedOutputNodes = args;
+ return this;
+ }
+
+ /**
+ * This method allows you to specify output nodes
+ *
+         * PLEASE NOTE: this argument is only used for SameDiff models
+ * @param args
+ * @return
+ */
+        public Builder<I, O> orderedOutputNodes(@NonNull List<String> args) {
+ Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
+ orderedOutputNodes = args.toArray(new String[args.size()]);
+ return this;
+ }
+
+        /**
+         * This method allows you to specify the http port
+         *
+         * PLEASE NOTE: the port must be free and within the regular TCP/IP port range
+         * @param port
+         * @return
+         */
+        public Builder<I, O> port(int port) {
+ this.port = port;
+ return this;
+ }
+
+        /**
+         * This method switches ParallelInference usage on or off
+         *
+         * PLEASE NOTE: this doesn't apply to SameDiff models
+         *
+         * @param enable true to use ParallelInference, false to use ComputationGraph or
+         *               MultiLayerNetwork directly
+         * @return
+         */
+        public Builder<I, O> parallelMode(boolean enable) {
+ this.parallelMode = enable;
+ return this;
+ }
+
+        public JsonModelServer<I, O> build() {
+ if (inferenceAdapter == null) {
+ if (inputAdapter != null && outputAdapter != null) {
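+                    // Bridge the two separate adapters into a single InferenceAdapter, so the
+                    // rest of the server only ever deals with one conversion interface.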
+                    inferenceAdapter = new InferenceAdapter<I, O>() {
+ @Override
+ public MultiDataSet apply(I input) {
+ return inputAdapter.apply(input);
+ }
+
+ @Override
+ public O apply(INDArray... outputs) {
+ return outputAdapter.apply(outputs);
+ }
+ };
+ } else
+ throw new IllegalArgumentException("Either InferenceAdapter or InputAdapter + OutputAdapter should be configured");
+ }
+
+            if (sdModel != null) {
+                Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
+                return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
+            } else if (cgModel != null)
+                return new JsonModelServer<I, O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
+            else if (mlnModel != null)
+                return new JsonModelServer<I, O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
+            else if (pi != null)
+                return new JsonModelServer<I, O>(pi, inferenceAdapter, serializer, deserializer, port);
+            else
+                throw new IllegalStateException("No models were defined for JsonModelServer");
+ }
+ }
+
+}
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java
new file mode 100644
index 000000000..aa353f307
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java
@@ -0,0 +1,749 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.MergeVertex;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.parallelism.inference.InferenceMode;
+import org.deeplearning4j.remote.helpers.House;
+import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
+import org.deeplearning4j.remote.helpers.PredictedPrice;
+import org.junit.After;
+import org.junit.Test;
+import org.nd4j.adapters.InferenceAdapter;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Sgd;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.nd4j.remote.clients.JsonRemoteInference;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
+import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
+import static org.junit.Assert.*;
+
+@Slf4j
+public class JsonModelServerTest {
+ private static final MultiLayerNetwork model;
+ private final int PORT = 18080;
+
+ static {
+ val conf = new NeuralNetConfiguration.Builder()
+ .seed(119)
+ .updater(new Adam(0.119f))
+ .weightInit(WeightInit.XAVIER)
+ .list()
+ .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build())
+ .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build())
+ .build();
+
+ model = new MultiLayerNetwork(conf);
+ model.init();
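+        // The fixed seed (119) makes the randomly-initialized weights deterministic;
+        // the hard-coded 0.421444f price assertions in the tests below rely on this.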
+ }
+
+ @After
+ public void pause() throws Exception {
+        // TODO: the same port is reused by consecutive tests and is not released
+        // immediately, hence the pause. There might be a better solution.
+ TimeUnit.SECONDS.sleep(2);
+ }
+
+
+ @Test
+ public void testStartStopParallel() throws Exception {
+ val sd = SameDiff.create();
+ val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
+ val result = sdVariable.add(1.0);
+ val total = result.mean("total", Integer.MAX_VALUE);
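+        // Toy SameDiff graph: total = mean(input + 1). With the house encoding used by
+        // HouseToPredictedPriceAdapter this comes out as district + 1 (see the 3.0 assertion below).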
+
+        val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .numWorkers(1)
+ .inferenceMode(SEQUENTIAL)
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .port(PORT)
+ .build();
+
+        val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .orderedInputNodes(new String[]{"input"})
+ .orderedOutputNodes(new String[]{"total"})
+ .port(PORT+1)
+ .build();
+ try {
+ serverDL.start();
+ serverSD.start();
+
+            val clientDL = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+ PredictedPrice price = clientDL.predict(house);
+ long timeStart = System.currentTimeMillis();
+ price = clientDL.predict(house);
+ long timeStop = System.currentTimeMillis();
+ log.info("Time spent: {} ms", timeStop - timeStart);
+ assertNotNull(price);
+ assertEquals((float) 0.421444, price.getPrice(), 1e-5);
+
+            val clientSD = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
+ .build();
+
+ PredictedPrice price2 = clientSD.predict(house);
+ timeStart = System.currentTimeMillis();
+ price = clientSD.predict(house);
+ timeStop = System.currentTimeMillis();
+ log.info("Time spent: {} ms", timeStop - timeStart);
+ assertNotNull(price);
+ assertEquals((float) 3.0, price.getPrice(), 1e-5);
+
+ }
+ finally {
+ serverSD.stop();
+ serverDL.stop();
+ }
+ }
+
+ @Test
+ public void testStartStopSequential() throws Exception {
+ val sd = SameDiff.create();
+ val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
+ val result = sdVariable.add(1.0);
+ val total = result.mean("total", Integer.MAX_VALUE);
+
+        val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .numWorkers(1)
+ .inferenceMode(SEQUENTIAL)
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .port(PORT)
+ .build();
+
+        val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .orderedInputNodes(new String[]{"input"})
+ .orderedOutputNodes(new String[]{"total"})
+ .port(PORT+1)
+ .build();
+
+ serverDL.start();
+ serverDL.stop();
+
+ serverSD.start();
+ serverSD.stop();
+ }
+
+ @Test
+ public void basicServingTestForSD() throws Exception {
+ val sd = SameDiff.create();
+ val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
+ val result = sdVariable.add(1.0);
+ val total = result.mean("total", Integer.MAX_VALUE);
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .orderedInputNodes(new String[]{"input"})
+ .orderedOutputNodes(new String[]{"total"})
+ .port(PORT)
+ .build();
+
+ try {
+ server.start();
+
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+
+ // warmup
+ PredictedPrice price = client.predict(house);
+
+ val timeStart = System.currentTimeMillis();
+ price = client.predict(house);
+ val timeStop = System.currentTimeMillis();
+
+ log.info("Time spent: {} ms", timeStop - timeStart);
+
+ assertNotNull(price);
+ assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
+ }
+ finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void basicServingTestForDLSynchronized() throws Exception {
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .numWorkers(1)
+ .inferenceMode(INPLACE)
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .port(PORT)
+ .build();
+
+ try {
+ server.start();
+
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+ House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build();
+ House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build();
+
+ // warmup
+ PredictedPrice price = client.predict(house1);
+
+ val timeStart = System.currentTimeMillis();
+ PredictedPrice price1 = client.predict(house1);
+ PredictedPrice price2 = client.predict(house2);
+ PredictedPrice price3 = client.predict(house3);
+ val timeStop = System.currentTimeMillis();
+
+ log.info("Time spent: {} ms", timeStop - timeStart);
+
+ assertNotNull(price);
+ assertEquals((float) 0.421444, price.getPrice(), 1e-5);
+
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void basicServingTestForDL() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .numWorkers(1)
+ .inferenceMode(SEQUENTIAL)
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .port(PORT)
+ .parallelMode(false)
+ .build();
+
+ try {
+ server.start();
+
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+
+ // warmup
+ PredictedPrice price = client.predict(house);
+
+ val timeStart = System.currentTimeMillis();
+ price = client.predict(house);
+ val timeStop = System.currentTimeMillis();
+
+ log.info("Time spent: {} ms", timeStop - timeStart);
+
+ assertNotNull(price);
+ assertEquals((float) 0.421444, price.getPrice(), 1e-5);
+
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void testDeserialization_1() {
+ String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
+ val deserializer = new House.HouseDeserializer();
+ val result = deserializer.deserialize(request);
+ assertEquals(2, result.getDistrict());
+ assertEquals(100, result.getArea());
+ assertEquals(2, result.getBathrooms());
+ assertEquals(3, result.getBedrooms());
+
+ }
+
+ @Test
+ public void testDeserialization_2() {
+ String request = "{\"price\":1}";
+ val deserializer = new PredictedPrice.PredictedPriceDeserializer();
+ val result = deserializer.deserialize(request);
+ assertEquals(1.0, result.getPrice(), 1e-4);
+ }
+
+ @Test(expected = NullPointerException.class)
+ public void negativeServingTest_1() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(null)
+ .port(18080)
+ .build();
+ }
+
+ @Test //(expected = NullPointerException.class)
+ public void negativeServingTest_2() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .port(PORT)
+ .build();
+
+ }
+
+ @Test(expected = IOException.class)
+ public void negativeServingTest_3() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .inferenceMode(SEQUENTIAL)
+ .numWorkers(1)
+ .port(PORT)
+ .build();
+
+ try {
+ server.start();
+
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+                    .outputDeserializer(new JsonDeserializer<PredictedPrice>() {
+ @Override
+ public PredictedPrice deserialize(String json) {
+ return null;
+ }
+ })
+ .endpointAddress("http://localhost:18080/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+
+ // warmup
+ PredictedPrice price = client.predict(house);
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void asyncServingTest() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .inferenceMode(SEQUENTIAL)
+ .numWorkers(1)
+ .port(PORT)
+ .build();
+
+ try {
+ server.start();
+
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+ .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+
+ val timeStart = System.currentTimeMillis();
+            Future<PredictedPrice> price = client.predictAsync(house);
+ assertNotNull(price);
+ assertEquals((float) 0.421444, price.get().getPrice(), 1e-5);
+ val timeStop = System.currentTimeMillis();
+
+ log.info("Time spent: {} ms", timeStop - timeStart);
+ }
+ finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void negativeAsyncTest() throws Exception {
+
+        val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
+ .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
+ .inputDeserializer(new House.HouseDeserializer())
+ .inferenceAdapter(new HouseToPredictedPriceAdapter())
+ .inferenceMode(InferenceMode.BATCHED)
+ .numWorkers(1)
+ .port(PORT)
+ .build();
+
+ try {
+ server.start();
+
+ // Fake deserializer to test failure
+            val client = JsonRemoteInference.<House, PredictedPrice>builder()
+ .inputSerializer(new House.HouseSerializer())
+                    .outputDeserializer(new JsonDeserializer<PredictedPrice>() {
+ @Override
+ public PredictedPrice deserialize(String json) {
+ return null;
+ }
+ })
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
+ .build();
+
+ int district = 2;
+ House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
+
+ val timeStart = System.currentTimeMillis();
+ try {
+                Future<PredictedPrice> price = client.predictAsync(house);
+ assertNotNull(price);
+ assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
+ val timeStop = System.currentTimeMillis();
+
+ log.info("Time spent: {} ms", timeStop - timeStart);
+ } catch (ExecutionException e) {
+ assertTrue(e.getMessage().contains("Deserialization failed"));
+ }
+ } finally {
+ server.stop();
+ }
+ }
+
+
+ @Test
+ public void testSameDiffMnist() throws Exception {
+
+ SameDiff sd = SameDiff.create();
+ SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
+ SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
+ SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
+ SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
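+        // A single softmax-regression layer: softmax(in * w + b) over 10 classes, randomly initialized.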
+
+        val server = new JsonModelServer.Builder<float[], Integer>(sd)
+ .outputSerializer( new IntSerde())
+ .inputDeserializer(new FloatSerde())
+                .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
+ @Override
+ public MultiDataSet apply(float[] input) {
+ return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
+ }
+
+ @Override
+ public Integer apply(INDArray... nnOutput) {
+ return nnOutput[0].argMax().getInt(0);
+ }
+ })
+ .orderedInputNodes("in")
+ .orderedOutputNodes("softmax")
+ .port(PORT+1)
+ .build();
+
+        val client = JsonRemoteInference.<float[], Integer>builder()
+ .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
+ .outputDeserializer(new IntSerde())
+ .inputSerializer( new FloatSerde())
+ .build();
+
+ try{
+ server.start();
+ for( int i=0; i<10; i++ ){
+ INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28);
+ INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax");
+ float[] fArr = f.toFloatVector();
+ int out = client.predict(fArr);
+ assertEquals(exp.argMax().getInt(0), out);
+ }
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void testMlnMnist() throws Exception {
+
+ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+ .list()
+ .layer(new DenseLayer.Builder().nIn(784).nOut(10).build())
+ .layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build())
+ .build();
+
+ MultiLayerNetwork net = new MultiLayerNetwork(conf);
+ net.init();
+
+        val server = new JsonModelServer.Builder<float[], Integer>(net)
+ .outputSerializer( new IntSerde())
+ .inputDeserializer(new FloatSerde())
+                .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
+ @Override
+ public MultiDataSet apply(float[] input) {
+ return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
+ }
+
+ @Override
+ public Integer apply(INDArray... nnOutput) {
+ return nnOutput[0].argMax().getInt(0);
+ }
+ })
+ .orderedInputNodes("in")
+ .orderedOutputNodes("softmax")
+ .port(PORT + 1)
+ .inferenceMode(SEQUENTIAL)
+ .numWorkers(2)
+ .build();
+
+        val client = JsonRemoteInference.<float[], Integer>builder()
+ .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
+ .outputDeserializer(new IntSerde())
+ .inputSerializer( new FloatSerde())
+ .build();
+
+ try {
+ server.start();
+ for (int i = 0; i < 10; i++) {
+ INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28);
+ INDArray exp = net.output(f);
+ float[] fArr = f.toFloatVector();
+ int out = client.predict(fArr);
+ assertEquals(exp.argMax().getInt(0), out);
+ }
+ } catch (Exception e){
+ e.printStackTrace();
+ throw e;
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void testCompGraph() throws Exception {
+
+ ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .addInputs("input1", "input2")
+ .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
+ .addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
+ .addVertex("merge", new MergeVertex(), "L1", "L2")
+ .addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge")
+ .setOutputs("out")
+ .build();
+
+ ComputationGraph net = new ComputationGraph(conf);
+ net.init();
+
+        val server = new JsonModelServer.Builder<float[], Integer>(net)
+ .outputSerializer( new IntSerde())
+ .inputDeserializer(new FloatSerde())
+                .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
+ @Override
+ public MultiDataSet apply(float[] input) {
+ return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
+ }
+
+ @Override
+ public Integer apply(INDArray... nnOutput) {
+ return nnOutput[0].argMax().getInt(0);
+ }
+ })
+ .orderedInputNodes("in")
+ .orderedOutputNodes("softmax")
+ .port(PORT + 1)
+ .inferenceMode(SEQUENTIAL)
+ .numWorkers(2)
+ .parallelMode(false)
+ .build();
+
+        val client = JsonRemoteInference.<float[], Integer>builder()
+ .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
+ .outputDeserializer(new IntSerde())
+ .inputSerializer( new FloatSerde())
+ .build();
+
+ try {
+ server.start();
+ //client.predict(new float[]{0.0f, 1.0f, 2.0f});
+ } catch (Exception e){
+ e.printStackTrace();
+ throw e;
+ } finally {
+ server.stop();
+ }
+ }
+
+ @Test
+ public void testCompGraph_1() throws Exception {
+
+ ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+ .updater(new Sgd(0.01))
+ .graphBuilder()
+ .addInputs("input")
+ .addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input")
+ .addLayer("out1", new OutputLayer.Builder()
+ .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
+ .nIn(4).nOut(3).build(), "L1")
+ .addLayer("out2", new OutputLayer.Builder()
+ .lossFunction(LossFunctions.LossFunction.MSE)
+ .nIn(4).nOut(2).build(), "L1")
+ .setOutputs("out1","out2")
+ .build();
+
+ final ComputationGraph net = new ComputationGraph(conf);
+ net.init();
+
+        val server = new JsonModelServer.Builder<float[], Integer>(net)
+ .outputSerializer( new IntSerde())
+ .inputDeserializer(new FloatSerde())
+                .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
+ @Override
+ public MultiDataSet apply(float[] input) {
+ return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
+ }
+
+ @Override
+ public Integer apply(INDArray... nnOutput) {
+ return nnOutput[0].argMax().getInt(0);
+ }
+ })
+ .orderedInputNodes("input")
+ .orderedOutputNodes("out")
+ .port(PORT + 1)
+ .inferenceMode(SEQUENTIAL)
+ .numWorkers(2)
+ .parallelMode(false)
+ .build();
+
+        val client = JsonRemoteInference.<float[], Integer>builder()
+ .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
+ .outputDeserializer(new IntSerde())
+ .inputSerializer( new FloatSerde())
+ .build();
+
+ try {
+ server.start();
+ val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
+ assertNotNull(result);
+ } catch (Exception e){
+ e.printStackTrace();
+ throw e;
+ } finally {
+ server.stop();
+ }
+ }
+
+    private static class FloatSerde implements JsonSerializer<float[]>, JsonDeserializer<float[]> {
+ private final ObjectMapper om = new ObjectMapper();
+
+ @Override
+ public float[] deserialize(String json) {
+ try {
+ return om.readValue(json, FloatHolder.class).getFloats();
+ } catch (IOException e){
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public String serialize(float[] o) {
+ try{
+ return om.writeValueAsString(new FloatHolder(o));
+ } catch (IOException e){
+ throw new RuntimeException(e);
+ }
+ }
+
+ //Use float holder so Jackson does ser/de properly (no "{}" otherwise)
+ @AllArgsConstructor @NoArgsConstructor @Data
+ private static class FloatHolder {
+ private float[] floats;
+ }
+ }
+
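+ // Test serde for Integer predictions, written as plain JSON numbers.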
+ private static class IntSerde implements JsonSerializer<Integer>, JsonDeserializer<Integer> {
+ private final ObjectMapper om = new ObjectMapper();
+
+ @Override
+ public Integer deserialize(String json) {
+ try {
+ return om.readValue(json, Integer.class);
+ } catch (IOException e){
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public String serialize(Integer o) {
+ try{
+ return om.writeValueAsString(o);
+ } catch (IOException e){
+ throw new RuntimeException(e);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java
new file mode 100644
index 000000000..1b347d112
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java
@@ -0,0 +1,133 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote;
+
+import lombok.val;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.nd4j.adapters.InferenceAdapter;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+
+public class ServletTest {
+
+ private JsonModelServer<String, String> server;
+
+ @Before
+ public void setUp() throws Exception {
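+ // Deliberately empty SameDiff model with stub adapter/serde implementations:
+ // these tests exercise only the servlet's HTTP behavior (status codes,
+ // content-type handling), not actual inference.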
+ val sd = SameDiff.create();
+ server = new JsonModelServer.Builder<String, String>(sd)
+ .port(8080)
+ .inferenceAdapter(new InferenceAdapter<String, String>() {
+ @Override
+ public MultiDataSet apply(String input) {
+ return null;
+ }
+
+ @Override
+ public String apply(INDArray... nnOutput) {
+ return null;
+ }
+ })
+ .outputSerializer(new JsonSerializer<String>() {
+ @Override
+ public String serialize(String o) {
+ return "";
+ }
+ })
+ .inputDeserializer(new JsonDeserializer<String>() {
+ @Override
+ public String deserialize(String json) {
+ return "";
+ }
+ })
+ .orderedInputNodes("input")
+ .orderedOutputNodes("output")
+ .build();
+
+ server.start();
+ //server.join();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ server.stop();
+ }
+
+ @Test
+ public void getEndpoints() throws IOException {
+ val request = new HttpGet( "http://localhost:8080/v1" );
+ request.setHeader("Content-type", "application/json");
+
+ val response = HttpClientBuilder.create().build().execute( request );
+ assertEquals(200, response.getStatusLine().getStatusCode());
+ }
+
+ @Test
+ public void testContentTypeGet() throws IOException {
+ val request = new HttpGet( "http://localhost:8080/v1" );
+ request.setHeader("Content-type", "text/plain");
+
+ val response = HttpClientBuilder.create().build().execute( request );
+ assertEquals(415, response.getStatusLine().getStatusCode());
+ }
+
+ @Test
+ public void testContentTypePost() throws Exception {
+ val request = new HttpPost("http://localhost:8080/v1/serving");
+ request.setHeader("Content-type", "text/plain");
+ val response = HttpClientBuilder.create().build().execute( request );
+ assertEquals(415, response.getStatusLine().getStatusCode());
+ }
+
+ @Test
+ public void postForServing() throws Exception {
+ val request = new HttpPost("http://localhost:8080/v1/serving");
+ request.setHeader("Content-type", "application/json");
+ val response = HttpClientBuilder.create().build().execute( request );
+ assertEquals(500, response.getStatusLine().getStatusCode());
+ }
+
+ @Test
+ public void testNotFoundPost() throws Exception {
+ val request = new HttpPost("http://localhost:8080/v1/serving/some");
+ request.setHeader("Content-type", "application/json");
+ val response = HttpClientBuilder.create().build().execute( request );
+ assertEquals(404, response.getStatusLine().getStatusCode());
+ }
+
+ @Test
+ public void testNotFoundGet() throws Exception {
+ val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
+ requestGet.setHeader("Content-type", "application/json");
+
+ val responseGet = HttpClientBuilder.create().build().execute( requestGet );
+ assertEquals(404, responseGet.getStatusLine().getStatusCode());
+ }
+
+}
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java
new file mode 100644
index 000000000..d66c8bae5
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java
@@ -0,0 +1,48 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote.helpers;
+
+import com.google.gson.Gson;
+import lombok.*;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+
+@Data
+@Builder
+@AllArgsConstructor
+@NoArgsConstructor
+public class House {
+ private int district;
+ private int bedrooms;
+ private int bathrooms;
+ private int area;
+
+
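+ // Gson-backed serde pair used by the remote-inference tests to move House
+ // instances over JSON.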
+ public static class HouseSerializer implements JsonSerializer<House> {
+ @Override
+ public String serialize(@NonNull House o) {
+ return new Gson().toJson(o);
+ }
+ }
+
+ public static class HouseDeserializer implements JsonDeserializer<House> {
+ @Override
+ public House deserialize(@NonNull String json) {
+ return new Gson().fromJson(json, House.class);
+ }
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java
new file mode 100644
index 000000000..82976a3da
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote.helpers;
+
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.nd4j.adapters.InferenceAdapter;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+
+@Slf4j
+public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
+
+ @Override
+ public MultiDataSet apply(@NonNull House input) {
+ // build a [1, 4] feature vector and fill it with the district value
+ return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null);
+ }
+
+ @Override
+ public PredictedPrice apply(INDArray... nnOutput) {
+ return new PredictedPrice(nnOutput[0].getFloat(0));
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java
new file mode 100644
index 000000000..c4024bb1d
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java
@@ -0,0 +1,47 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.remote.helpers;
+
+import com.google.gson.Gson;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.nd4j.remote.clients.serde.JsonDeserializer;
+import org.nd4j.remote.clients.serde.JsonSerializer;
+
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+public class PredictedPrice {
+ private float price;
+
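+ // Gson-backed serde pair for the predicted-price response payload.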
+ public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
+ @Override
+ public String serialize(@NonNull PredictedPrice o) {
+ return new Gson().toJson(o);
+ }
+ }
+
+ public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
+ @Override
+ public PredictedPrice deserialize(@NonNull String json) {
+ return new Gson().fromJson(json, PredictedPrice.class);
+ }
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml
new file mode 100644
index 000000000..59b35644e
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml
@@ -0,0 +1,48 @@
+<!--
+  ~ Copyright (c) 2015-2018 Skymind, Inc.
+  ~
+  ~ This program and the accompanying materials are made available under the
+  ~ terms of the Apache License, Version 2.0 which is available at
+  ~ https://www.apache.org/licenses/LICENSE-2.0.
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+  ~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+  ~ License for the specific language governing permissions and limitations
+  ~ under the License.
+  ~
+  ~ SPDX-License-Identifier: Apache-2.0
+  -->
+
+<configuration>
+
+    <appender name="FILE" class="ch.qos.logback.core.FileAppender">
+        <file>logs/application.log</file>
+        <encoder>
+            <pattern>%date - [%level] - from %logger in %thread
+                %n%message%n%xException%n</pattern>
+        </encoder>
+    </appender>
+
+    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+        <encoder>
+            <pattern>%logger{15} - %message%n%xException{5}</pattern>
+        </encoder>
+    </appender>
+
+    <logger name="org.deeplearning4j" level="INFO"/>
+    <logger name="org.nd4j" level="INFO"/>
+
+    <root level="ERROR">
+        <appender-ref ref="STDOUT"/>
+        <appender-ref ref="FILE"/>
+    </root>
+
+</configuration>
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml
new file mode 100644
index 000000000..1fc937e10
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-remote/pom.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <packaging>pom</packaging>
+
+    <modules>
+        <module>deeplearning4j-json-server</module>
+    </modules>
+
+    <parent>
+        <groupId>org.deeplearning4j</groupId>
+        <artifactId>deeplearning4j</artifactId>
+        <version>1.0.0-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>deeplearning4j-remote</artifactId>
+    <version>1.0.0-SNAPSHOT</version>
+    <name>deeplearning4j-remote</name>
+
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.1</id>
+        </profile>
+    </profiles>
+</project>
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java
index 48966a57d..149a122f2 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java
@@ -21,7 +21,6 @@ import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
-import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java
index 022b28076..0e2fd339a 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java
@@ -22,7 +22,6 @@ import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
-import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index e13b62cda..b8b70befb 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -144,6 +144,7 @@
dl4j-perf
dl4j-integration-tests
deeplearning4j-common
+ deeplearning4j-remote
diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt
index 2e3c51091..8e940bedb 100755
--- a/libnd4j/blas/CMakeLists.txt
+++ b/libnd4j/blas/CMakeLists.txt
@@ -163,9 +163,9 @@ if(CUDA_BLAS)
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
@@ -173,24 +173,24 @@ if(CUDA_BLAS)
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
endif()
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0
if ("${COMPUTE}" STREQUAL "all")
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
else()
if ("${COMPUTE}" STREQUAL "all")
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
endif()
@@ -205,34 +205,34 @@ if(CUDA_BLAS)
message("CUDA 10 Debug build")
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
elseif()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
endif()
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
elseif()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
endif()
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8
if ("${COMPUTE}" STREQUAL "all")
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
else()
if ("${COMPUTE}" STREQUAL "all")
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
endif()
endif()
@@ -249,7 +249,7 @@ if(CUDA_BLAS)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
- file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
+ file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)
diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp
index a0529d106..726549415 100644
--- a/libnd4j/blas/NDArray.hpp
+++ b/libnd4j/blas/NDArray.hpp
@@ -344,7 +344,7 @@ bool NDArray::isS() const {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isR() const {
auto xType = ArrayOptions::dataType(this->_shapeInfo);
- return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8;
+ return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16;
}
//////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h
index 87555a303..9ce90176f 100755
--- a/libnd4j/blas/NativeOps.h
+++ b/libnd4j/blas/NativeOps.h
@@ -1769,6 +1769,17 @@ ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
+typedef nd4j::LaunchContext OpaqueLaunchContext;
+
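+// Accessors exposing LaunchContext internals through the C API. The CUDA backend
+// returns live stream/handle pointers; the CPU backend stubs them out with
+// nullptr (see blas/cpu/NativeOps.cpp and blas/cuda/NativeOps.cu).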
+ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext();
+ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
+ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
+
}
#endif //NATIVEOPERATIONS_NATIVEOPS_H
diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp
index 74bd072c8..f5d4996e4 100644
--- a/libnd4j/blas/cpu/NativeOps.cpp
+++ b/libnd4j/blas/cpu/NativeOps.cpp
@@ -2985,6 +2985,38 @@ const char* runFullBenchmarkSuit(bool printOut) {
return chars;
}
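+
+// CPU backend: there are no streams or library handles here, so the lc*
+// accessors below are nullptr stubs; only defaultLaunchContext() is meaningful.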
+nd4j::LaunchContext* defaultLaunchContext() {
+ return LaunchContext::defaultContext();
+}
+
+Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
+Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
+ return nullptr;
+}
+
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu
index 67173c971..126837ad9 100644
--- a/libnd4j/blas/cuda/NDArray.cu
+++ b/libnd4j/blas/cuda/NDArray.cu
@@ -356,7 +356,7 @@ void NDArray::tile(const std::vector& reps, NDArray& target) const {
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
- BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@@ -375,7 +375,7 @@ void NDArray::tile(NDArray& target) const {
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
- BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@@ -434,7 +434,7 @@ void NDArray::repeat(int dimension, NDArray& target) const {
NDArray::prepareSpecialUse({&target}, {this});
auto stream = getContext()->getCudaStream();
- BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES, LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES);
NDArray::registerSpecialUse({&target}, {this});
}
diff --git a/libnd4j/blas/cuda/NDArrayLambda.hpp b/libnd4j/blas/cuda/NDArrayLambda.hpp
index f7846e121..bf9848981 100644
--- a/libnd4j/blas/cuda/NDArrayLambda.hpp
+++ b/libnd4j/blas/cuda/NDArrayLambda.hpp
@@ -23,6 +23,14 @@
#include
#include
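+// __noinline__ device wrappers around the shape:: helpers, so the index math is
+// emitted once instead of being inlined into every templated lambda kernel
+// instantiation below (presumably a code-size optimization).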
+static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
+ return shape::getIndexOffset(index, shapeInfo, length);
+}
+
+static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
+ return shape::length(shapeInfo);
+}
+
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
@@ -86,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
- auto zLength = shape::length(zShapeInfo);
+ auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -95,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
z[e * zEws] = lambda(x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
- auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
+ auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
+ auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(x[xOffset]);
}
@@ -115,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
- auto zLength = shape::length(zShapeInfo);
+ auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -124,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
z[e * zEws] = lambda(e, x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
- auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
+ auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
+ auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(e, x[xOffset]);
}
@@ -147,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
- auto zLength = shape::length(zShapeInfo);
+ auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -156,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
- auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
- auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
+ auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
+ auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
+ auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
}
@@ -180,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
- auto zLength = shape::length(zShapeInfo);
+ auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -189,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
- auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
- auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
+ auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
+ auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
+ auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(x[xOffset], y[yOffset]);
}
@@ -216,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
- auto zLength = shape::length(zShapeInfo);
+ auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -225,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
- auto wOffset = shape::getIndexOffset(e, wShapeInfo, zLength);
- auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
- auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
- auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
+ auto wOffset = __getIndexOffset(e, wShapeInfo, zLength);
+ auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
+ auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
+ auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
}
diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu
index 0d441fe5e..af9fc6776 100755
--- a/libnd4j/blas/cuda/NativeOps.cu
+++ b/libnd4j/blas/cuda/NativeOps.cu
@@ -28,6 +28,7 @@
#include
#include
#include
+#include <execution/AffinityManager.h>
#include
#include
@@ -1691,11 +1692,7 @@ void setOmpMinThreads(int threads) {
}
int getDevice() {
- int curDevice = -1;
-
- cudaGetDevice(&curDevice);
-
- return curDevice;
+ return nd4j::AffinityManager::currentDeviceId();
}
void setElementThreshold(int num) {
@@ -2391,8 +2388,8 @@ void sortByValue(Nd4jPointer *extraPointers,
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
- auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
- auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
+ auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
+ auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
@@ -2406,7 +2403,7 @@ void sortByValue(Nd4jPointer *extraPointers,
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
- BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+ BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
@@ -2430,7 +2427,7 @@ void sortByValue(Nd4jPointer *extraPointers,
int rev = 0;
do{
int half = n >> 1;
- BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+ BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
@@ -3342,6 +3339,7 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
nd4j::graph::Context* createGraphContext(int nodeId) {
return new nd4j::graph::Context(nodeId);
}
+
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
return &ptr->randomGenerator();
}
@@ -3460,3 +3458,35 @@ const char* runFullBenchmarkSuit(bool printOut) {
Nd4jLong getCachedMemory(int deviceId) {
return nd4j::ConstantHelper::getInstance()->getCachedAmount(deviceId);
}
+
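+// CUDA backend: expose the per-context streams and cuBLAS/cuSolver handles to
+// the Java side (cf. the nullptr stubs in blas/cpu/NativeOps.cpp).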
+nd4j::LaunchContext* defaultLaunchContext() {
+ return LaunchContext::defaultContext();
+}
+
+Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
+ return lc->getScalarPointer();
+}
+
+Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
+ return lc->getReductionPointer();
+}
+
+Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
+ return lc->getAllocationPointer();
+}
+
+Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
+ return lc->getCudaStream();
+}
+
+Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
+ return lc->getCudaSpecialStream();
+}
+
+Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
+ return lc->getCublasHandle();
+}
+
+Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
+ return lc->getCusolverHandle();
+}
\ No newline at end of file
diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h
new file mode 100644
index 000000000..463d6942e
--- /dev/null
+++ b/libnd4j/include/execution/AffinityManager.h
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#ifndef LIBND4J_AFFINITYMANAGER_H
+#define LIBND4J_AFFINITYMANAGER_H
+
+#include <atomic>
+#include <mutex>
+#include <dll.h>
+#include <pointercast.h>
+
+namespace nd4j {
+ class ND4J_EXPORT AffinityManager {
+ private:
+ static std::atomic<int> _lastDevice;
+ static int _numberOfDevices;
+ static std::mutex _currentMutex;
+ static std::mutex _numberMutex;
+
+ public:
+ static int currentNativeDeviceId();
+ static int currentDeviceId();
+ static int numberOfDevices();
+ static void setCurrentDevice(int deviceId);
+ static void setCurrentNativeDevice(int deviceId);
+ };
+}
+
+#endif //LIBND4J_AFFINITYMANAGER_H
diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h
new file mode 100644
index 000000000..77b8d4ca3
--- /dev/null
+++ b/libnd4j/include/execution/ContextBuffers.h
@@ -0,0 +1,58 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#ifndef LIBND4J_CONTEXTBUFFERS_H
+#define LIBND4J_CONTEXTBUFFERS_H
+
+#include <dll.h>
+#include <pointercast.h>
+
+namespace nd4j {
+ class ND4J_EXPORT ContextBuffers {
+ private:
+ void* _reductionPointer;
+ void* _scalarPointer;
+ void* _allocationPointer;
+ bool _allocated = true;
+
+ int _deviceId = -1;
+
+ void initialize();
+ public:
+ ContextBuffers();
+ ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
+ ~ContextBuffers();
+
+ void* reductionBuffer();
+ void* scalarBuffer();
+ void* allocationBuffer();
+
+ void setReductionBuffer(void* pointer);
+ void setScalarBuffer(void* pointer);
+ void setAllocationBuffer(void* pointer);
+
+ void triggerOwnership(bool isOwner);
+
+ int deviceId();
+ };
+}
+
+
+#endif //LIBND4J_CONTEXTBUFFERS_H
diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h
index 853a970d2..02b772415 100644
--- a/libnd4j/include/execution/LaunchContext.h
+++ b/libnd4j/include/execution/LaunchContext.h
@@ -35,6 +35,8 @@
#include
#include
#include
+#include <execution/ContextBuffers.h>
+#include <mutex>
@@ -44,49 +46,44 @@ class ND4J_EXPORT LaunchContext {
private:
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
+ static std::mutex _mutex;
#ifdef __CUDABLAS__
#ifndef __JAVACPP_HACK__
- void* _reductionPointer;
- void* _scalarPointer;
- int* _allocationPointer;
- cudaStream_t *_cudaStream = nullptr;
- cudaStream_t *_cudaSpecialStream = nullptr;
- void *_cublasHandle = nullptr;
+ cudaStream_t* _cudaStream = nullptr;
+ cudaStream_t* _cudaSpecialStream = nullptr;
+ void* _cublasHandle = nullptr;
+ void* _cusolverHandle = nullptr;
#endif // JCPP
bool _isAllocated = false;
#endif // CUDA
- nd4j::memory::Workspace* _workspace = nullptr;
- int _deviceID = 0;
+ nd4j::memory::Workspace* _workspace = nullptr;
+ int _deviceID = 0;
+
public:
#ifdef __CUDABLAS__
#ifndef __JAVACPP_HACK__
LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr);
- FORCEINLINE void* getReductionPointer () const {return _reductionPointer;};
+ void* getReductionPointer () const;
+ void* getScalarPointer() const;
+ int* getAllocationPointer() const;
+ void* getCublasHandle() const;
+ void* getCusolverHandle() const;
+ cudaStream_t* getCudaStream() const;
+ cudaStream_t* getCudaSpecialStream() const;
- FORCEINLINE void* getScalarPointer() const {return _scalarPointer;};
-
- FORCEINLINE int* getAllocationPointer() const {return _allocationPointer;};
-
- FORCEINLINE void* getCublasHandle() const {return _cublasHandle;};
- FORCEINLINE cudaStream_t* getCudaStream() const {return _cudaStream;};
- FORCEINLINE cudaStream_t* getCudaSpecialStream() const {return _cudaSpecialStream;};
-
- FORCEINLINE void setReductionPointer (void* reductionPointer) {_reductionPointer = reductionPointer;};
-
- FORCEINLINE void setScalarPointer(void* scalarPointer) {_scalarPointer = scalarPointer;};
-
- FORCEINLINE void setAllocationPointer(int* allocationPointer) {_allocationPointer = allocationPointer;};
-
- FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
- FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
- FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; };
+ void setReductionPointer (void* reductionPointer);
+ void setScalarPointer(void* scalarPointer);
+ void setAllocationPointer(int* allocationPointer);
+ void setCudaStream(cudaStream_t* cudaStream);
+ void setCudaSpecialStream(cudaStream_t* cudaStream);
+ void setCublasHandle(void *handle);
#endif // JCPP
diff --git a/libnd4j/include/execution/cpu/AffinityManager.cpp b/libnd4j/include/execution/cpu/AffinityManager.cpp
new file mode 100644
index 000000000..7927982a6
--- /dev/null
+++ b/libnd4j/include/execution/cpu/AffinityManager.cpp
@@ -0,0 +1,43 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#include <execution/AffinityManager.h>
+
+namespace nd4j {
+ int AffinityManager::currentDeviceId() {
+ return 0;
+ }
+
+ int AffinityManager::currentNativeDeviceId() {
+ return 0;
+ }
+
+ int AffinityManager::numberOfDevices() {
+ return 1;
+ }
+
+ void AffinityManager::setCurrentDevice(int deviceId) {
+ // no-op
+ }
+
+ void AffinityManager::setCurrentNativeDevice(int deviceId) {
+ // no-op
+ }
+}
\ No newline at end of file
diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp
new file mode 100644
index 000000000..d385548d0
--- /dev/null
+++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp
@@ -0,0 +1,74 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+#include <execution/ContextBuffers.h>
+#include <execution/AffinityManager.h>
+
+namespace nd4j {
+ ContextBuffers::ContextBuffers() {
+ _deviceId = AffinityManager::currentDeviceId();
+ }
+
+ ContextBuffers::~ContextBuffers() {
+ // no-op
+ }
+
+ ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
+ _reductionPointer = rPointer;
+ _scalarPointer = sPointer;
+ _allocationPointer = aPointer;
+ _allocated = isOwner;
+ }
+
+ void ContextBuffers::initialize() {
+ // no-op
+ }
+
+ void* ContextBuffers::reductionBuffer() {
+ return _reductionPointer;
+ }
+
+ void* ContextBuffers::scalarBuffer() {
+ return _scalarPointer;
+ }
+
+ void* ContextBuffers::allocationBuffer() {
+ return _allocationPointer;
+ }
+
+ void ContextBuffers::setReductionBuffer(void* pointer) {
+ _reductionPointer = pointer;
+ }
+
+ void ContextBuffers::setScalarBuffer(void* pointer) {
+ _scalarPointer = pointer;
+ }
+
+ void ContextBuffers::setAllocationBuffer(void* pointer) {
+ _allocationPointer = pointer;
+ }
+
+ void ContextBuffers::triggerOwnership(bool isOwner) {
+ _allocated = isOwner;
+ }
+
+ int ContextBuffers::deviceId() {
+ return _deviceId;
+ }
+}
diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp
new file mode 100644
index 000000000..47207719f
--- /dev/null
+++ b/libnd4j/include/execution/cpu/LaunchContext.cpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by raver119 on 30.11.17.
+//
+
+#include <execution/LaunchContext.h>
+#include <execution/ContextBuffers.h>
+#include
+#include
+
+thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
+
+namespace nd4j {
+
+ LaunchContext::~LaunchContext() {
+
+ }
+
+ std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
+
+////////////////////////////////////////////////////////////////////////
+ LaunchContext::LaunchContext() {
+ // default constructor, just to make clang/ranlib happy
+ _workspace = nullptr;
+ _deviceID = 0;
+ }
+
+ LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
+
+ }
+
+ LaunchContext* LaunchContext::defaultContext() {
+ // TODO: we need it to be device-aware, but only once we add NUMA support for cpu
+ if (LaunchContext::_contexts.empty()) {
+ LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
+ }
+
+ // return context for current device
+ return LaunchContext::_contexts[0].get();
+ }
+}
\ No newline at end of file
diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu
new file mode 100644
index 000000000..811dc267a
--- /dev/null
+++ b/libnd4j/include/execution/cuda/AffinityManager.cu
@@ -0,0 +1,108 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#include <execution/AffinityManager.h>
+#include <exceptions/cuda_exception.h>
+#include
+
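+// Per-thread device affinity: -1 means "not assigned yet"; the first call to
+// currentDeviceId() assigns a device round-robin via _lastDevice.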
+thread_local int globalThreadToDevice = -1;
+
+namespace nd4j {
+ std::mutex AffinityManager::_currentMutex;
+ std::mutex AffinityManager::_numberMutex;
+ int AffinityManager::_numberOfDevices = -1;
+
+ int AffinityManager::currentDeviceId() {
+ // if there's no affinity set - set it now
+ if (globalThreadToDevice < 0) {
+
+ // this block must be synchronized: _lastDevice is shared across threads
+ _currentMutex.lock();
+
+ globalThreadToDevice = _lastDevice++;
+
+ // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise
+ if (globalThreadToDevice >= numberOfDevices()) {
+ globalThreadToDevice = 0;
+ _lastDevice = numberOfDevices() > 1 ? 1 : 0;
+ }
+
+ _currentMutex.unlock();
+
+ setCurrentDevice(globalThreadToDevice);
+ }
+
+ // if we already know affinity - just return it
+ if (globalThreadToDevice >= 0)
+ return globalThreadToDevice;
+
+ int dev = 0;
+ auto res = cudaGetDevice(&dev);
+
+ if (res != 0)
+ throw cuda_exception::build("cudaGetDevice failed", res);
+
+ return dev;
+ }
+
+ int AffinityManager::currentNativeDeviceId() {
+ int dev = 0;
+ auto res = cudaGetDevice(&dev);
+
+ if (res != 0)
+ throw cuda_exception::build("cudaGetDevice failed", res);
+
+ return dev;
+ }
+
+ int AffinityManager::numberOfDevices() {
+ _numberMutex.lock();
+ // we want to cache number of devices
+ if (_numberOfDevices <= 0) {
+ int dev = 0;
+ auto res = cudaGetDeviceCount(&dev);
+
+ if (res != 0)
+ throw cuda_exception::build("cudaGetDeviceCount failed", res);
+
+ _numberOfDevices = dev;
+ }
+ _numberMutex.unlock();
+
+ return _numberOfDevices;
+ }
+
+ void AffinityManager::setCurrentNativeDevice(int deviceId) {
+ auto res = cudaSetDevice(deviceId);
+ if (res != 0)
+ throw cuda_exception::build("cudaSetDevice failed", res);
+ }
+
+ void AffinityManager::setCurrentDevice(int deviceId) {
+ auto res = cudaSetDevice(deviceId);
+ if (res != 0)
+ throw cuda_exception::build("cudaSetDevice failed", res);
+
+ // update thread-device affinity
+ globalThreadToDevice = deviceId;
+
+ // TODO: update context buffers?
+ }
+
+ std::atomic<int> AffinityManager::_lastDevice;
+}
\ No newline at end of file
diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu
new file mode 100644
index 000000000..f82747b91
--- /dev/null
+++ b/libnd4j/include/execution/cuda/ContextBuffers.cu
@@ -0,0 +1,116 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author raver119@gmail.com
+//
+
+#include <execution/ContextBuffers.h>
+#include <execution/AffinityManager.h>
+#include <logger.h>
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <cuda_runtime.h>
+#include <cuda_device_runtime_api.h>
+
+namespace nd4j {
+ ContextBuffers::ContextBuffers() {
+ nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId());
+ _deviceId = AffinityManager::currentDeviceId();
+ }
+
+ ContextBuffers::~ContextBuffers() {
+ if (_allocated) {
+ nd4j_printf("Releasing ContextBuffers\n","");
+
+ if (_allocationPointer != nullptr)
+ cudaFree(_allocationPointer);
+
+ if (_scalarPointer != nullptr)
+ cudaFree(_scalarPointer);
+
+            if (_reductionPointer != nullptr)
+                cudaFree(_reductionPointer);
+ }
+ }
+
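+    // wraps externally-managed buffers; isOwner controls whether the destructor will cudaFree them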
+ ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
+ _reductionPointer = rPointer;
+ _scalarPointer = sPointer;
+ _allocationPointer = aPointer;
+ _allocated = isOwner;
+ }
+
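+    // buffers are allocated lazily, on first use by a given thread:
+    // 8 MB reduction buffer, 16-byte scalar buffer, 8 MB allocation scratch buffer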
+ void ContextBuffers::initialize() {
+ nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId());
+
+        auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
+ if (res != 0)
+ throw std::runtime_error("_reductionPointer allocation failed");
+
+        res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
+ if (res != 0)
+ throw std::runtime_error("_scalarPointer allocation failed");
+
+        res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
+ if (res != 0)
+ throw std::runtime_error("_allocationPointer allocation failed");
+
+ _allocated = true;
+ }
+
+ void* ContextBuffers::reductionBuffer() {
+ if (_reductionPointer == nullptr)
+ initialize();
+
+ return _reductionPointer;
+ }
+
+ void* ContextBuffers::scalarBuffer() {
+ if (_scalarPointer == nullptr)
+ initialize();
+
+ return _scalarPointer;
+ }
+
+ void* ContextBuffers::allocationBuffer() {
+ if (_allocationPointer == nullptr)
+ initialize();
+
+ return _allocationPointer;
+ }
+
+ void ContextBuffers::setReductionBuffer(void* pointer) {
+ _reductionPointer = pointer;
+ }
+
+ void ContextBuffers::setScalarBuffer(void* pointer) {
+ _scalarPointer = pointer;
+ }
+
+ void ContextBuffers::setAllocationBuffer(void* pointer) {
+ _allocationPointer = pointer;
+ }
+
+ void ContextBuffers::triggerOwnership(bool isOwner) {
+ _allocated = isOwner;
+ }
+
+ int ContextBuffers::deviceId() {
+ return _deviceId;
+ }
+}
diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu
new file mode 100644
index 000000000..004ed2cac
--- /dev/null
+++ b/libnd4j/include/execution/cuda/LaunchContext.cu
@@ -0,0 +1,182 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by raver119 on 30.11.17.
+//
+
+#include <execution/LaunchContext.h>
+#include <execution/ContextBuffers.h>
+#include <execution/AffinityManager.h>
+#include <exceptions/cuda_exception.h>
+#include <helpers/cublasHelper.h>
+#include <logger.h>
+
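+// per-thread set of temporary device buffers, allocated lazily on first use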
+thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
+
+namespace nd4j {
+
+    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
+ std::mutex LaunchContext::_mutex;
+
+////////////////////////////////////////////////////////////////////////
+LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
+
+ _cudaStream = cudaStream;
+ _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
+ //_reductionPointer = reductionPointer;
+ //_scalarPointer = scalarPointer;
+ //_allocationPointer = allocationPointer;
+ _workspace = nullptr;
+ _isAllocated = false;
+}
+
+LaunchContext::~LaunchContext() {
+ if (_isAllocated) {
+ cudaStreamSynchronize(*_cudaStream);
+ cudaStreamSynchronize(*_cudaSpecialStream);
+
+ cudaStreamDestroy(*_cudaStream);
+ cudaStreamDestroy(*_cudaSpecialStream);
+
+ delete _cudaStream;
+ delete _cudaSpecialStream;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////
+LaunchContext::LaunchContext() {
+    // default constructor: creates this context's CUDA streams and fetches the per-device cuBLAS/cuSolver handles
+ _workspace = nullptr;
+ _deviceID = 0;
+
+ _isAllocated = true;
+ _cudaStream = new cudaStream_t();
+ _cudaSpecialStream = new cudaStream_t();
+ if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
+ throw std::runtime_error("Failed to allocate memory for new CUDA stream");
+
+ cudaError_t err = cudaStreamCreate(_cudaStream);
+ if (err != 0)
+ throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);
+
+ err = cudaStreamCreate(_cudaSpecialStream);
+ if (err != 0)
+ throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);
+
+ _cublasHandle = CublasHelper::getInstance()->handle();
+
+ _cusolverHandle = CublasHelper::getInstance()->solver();
+
+ auto res = cudaStreamSynchronize(*_cudaStream);
+ if (res != 0)
+ throw cuda_exception::build("Initial sync failed", res);
+}
+
+ LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
+ _isAllocated = false;
+        _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
+        _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
+ //_reductionPointer = reductionPointer;
+ //_scalarPointer = scalarPointer;
+ //_allocationPointer = reinterpret_cast(allocationPointer);
+ }
+
+ LaunchContext* LaunchContext::defaultContext() {
+        /**
+         * This method returns a LaunchContext that bundles several entities:
+         * 1) temporary buffers - these must be per-thread
+         * 2) CUDA stream - must be either per-thread or per-device
+         * 3) cuBLAS handle - must be per-device
+         */
+ auto deviceId = AffinityManager::currentDeviceId();
+
+        // this block must be synchronized, to avoid double initialization
+ _mutex.lock();
+ if (LaunchContext::_contexts.empty()) {
+ // create one context per device
+ auto numDevices = AffinityManager::numberOfDevices();
+
+ _contexts.resize(numDevices);
+ for (int e = 0; e < numDevices; e++) {
+ AffinityManager::setCurrentDevice(e);
+
+                LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
+ }
+
+            // restore the original device
+ AffinityManager::setCurrentDevice(deviceId);
+ }
+ _mutex.unlock();
+
+ // return context for current device
+ return LaunchContext::_contexts[deviceId].get();
+ }
+
+
+ void* LaunchContext::getReductionPointer () const {
+ return contextBuffers.reductionBuffer();
+ };
+
+ void* LaunchContext::getScalarPointer() const {
+ return contextBuffers.scalarBuffer();
+ };
+
+ int* LaunchContext::getAllocationPointer() const {
+        return reinterpret_cast<int*>(contextBuffers.allocationBuffer());
+ };
+
+ void* LaunchContext::getCublasHandle() const {
+ return _cublasHandle;
+ };
+
+ void* LaunchContext::getCusolverHandle() const {
+ return _cusolverHandle;
+ };
+
+ cudaStream_t* LaunchContext::getCudaStream() const {
+ return _cudaStream;
+ };
+
+ cudaStream_t* LaunchContext::getCudaSpecialStream() const {
+ return _cudaSpecialStream;
+ };
+
+
+ void LaunchContext::setReductionPointer (void* reductionPointer) {
+ contextBuffers.setReductionBuffer(reductionPointer);
+ };
+
+ void LaunchContext::setScalarPointer(void* scalarPointer) {
+ contextBuffers.setScalarBuffer(scalarPointer);
+ };
+
+ void LaunchContext::setAllocationPointer(int* allocationPointer) {
+ contextBuffers.setAllocationBuffer(allocationPointer);
+ };
+
+ void LaunchContext::setCudaStream(cudaStream_t* cudaStream) {
+ _cudaStream = cudaStream;
+ };
+
+ void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) {
+ _cudaSpecialStream = cudaStream;
+ };
+
+ void LaunchContext::setCublasHandle(void *handle) {
+ _cublasHandle = handle;
+ };
+}
\ No newline at end of file
diff --git a/libnd4j/include/execution/impl/LaunchContext.cpp b/libnd4j/include/execution/impl/LaunchContext.cpp
deleted file mode 100644
index edc95dabc..000000000
--- a/libnd4j/include/execution/impl/LaunchContext.cpp
+++ /dev/null
@@ -1,130 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// Created by raver119 on 30.11.17.
-//
-
-#include <execution/LaunchContext.h>
-#include <logger.h>
-#include <exceptions/cuda_exception.h>
-#include <helpers/cublasHelper.h>
-
-namespace nd4j {
-
-#ifdef __CUDABLAS__
-
-////////////////////////////////////////////////////////////////////////
-LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
-
- _cudaStream = cudaStream;
- _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
- _reductionPointer = reductionPointer;
- _scalarPointer = scalarPointer;
- _allocationPointer = allocationPointer;
- _workspace = nullptr;
- _isAllocated = false;
-}
-#endif
-
-LaunchContext::~LaunchContext() {
-#ifdef __CUDABLAS__
- if (_isAllocated) {
- cudaStreamSynchronize(*_cudaStream);
- cudaStreamSynchronize(*_cudaSpecialStream);
-
- cudaStreamDestroy(*_cudaStream);
- cudaStreamDestroy(*_cudaSpecialStream);
-
- delete _cudaStream;
- delete _cudaSpecialStream;
-
- cudaFree(_reductionPointer);
- cudaFree(_allocationPointer);
- cudaFree(_scalarPointer);
-
- cublas::destroyHandle(_cublasHandle);
- }
-#endif
-}
-
-    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
-
-////////////////////////////////////////////////////////////////////////
-LaunchContext::LaunchContext() {
- // default constructor, just to make clang/ranlib happy
- _workspace = nullptr;
- _deviceID = 0;
-
-#ifdef __CUDABLAS__
- _isAllocated = true;
- _cudaStream = new cudaStream_t();
- _cudaSpecialStream = new cudaStream_t();
- if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
- throw std::runtime_error("Failed to allocate memory for new CUDA stream");
-
- cudaError_t err = cudaStreamCreate(_cudaStream);
- if (err != 0)
- throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);
-
- err = cudaStreamCreate(_cudaSpecialStream);
- if (err != 0)
- throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);
-
- _cublasHandle = cublas::handle();
-
- auto res = cudaStreamSynchronize(*_cudaStream);
- if (res != 0)
- throw cuda_exception::build("Initial sync failed", res);
-
-    res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
- if (res != 0)
- throw std::runtime_error("_reductionPointer allocation failed");
-
-    res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 8);
- if (res != 0)
- throw std::runtime_error("_scalarPointer allocation failed");
-
-    res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
- if (res != 0)
- throw std::runtime_error("_allocationPointer allocation failed");
-#else
- //
-#endif
-}
-
- LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
-#ifdef __CUDABLAS__
- _isAllocated = false;
-        _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
-        _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
-        _reductionPointer = reductionPointer;
-        _scalarPointer = scalarPointer;
-        _allocationPointer = reinterpret_cast<int*>(allocationPointer);
-#else
- // no-op
-#endif
- }
-
-LaunchContext* LaunchContext::defaultContext() {
- // TODO: we need it to be device-aware
- if (LaunchContext::_contexts.empty()) {
-        LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
- }
- return LaunchContext::_contexts[0].get();
-}
-
-}
\ No newline at end of file
diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp
index f74bd5637..43a4f97c1 100644
--- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp
+++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp
@@ -21,6 +21,7 @@
#ifndef __CUDABLAS__
#include
+#include <execution/AffinityManager.h>
#include
#include
#include
@@ -59,11 +60,11 @@ namespace nd4j {
}
int ConstantHelper::getCurrentDevice() {
- return 0L;
+ return AffinityManager::currentDeviceId();
}
int ConstantHelper::getNumberOfDevices() {
- return 1;
+ return AffinityManager::numberOfDevices();
}
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp
index 293360a25..d17d2c021 100644
--- a/libnd4j/include/helpers/cpu/MmulHelper.cpp
+++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp
@@ -21,6 +21,7 @@
#include "../MmulHelper.h"
#include
#include
+#include <exceptions/datatype_exception.h>
namespace nd4j {
@@ -147,7 +148,12 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
//////////////////////////////////////////////////////////////////////////////
// MXK x KxN = MxN
-NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
+NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
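+    // gemm is dispatched on a single template type below, so all operands must share one data type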
+ if (A->dataType() != B->dataType())
+ throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType());
+
+ if (C != nullptr && A->dataType() != C->dataType())
+ throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType());
if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !");
@@ -212,7 +218,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
}
else {
- BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES);
+ //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}
if(pC != C) {
@@ -230,6 +237,11 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
////////////////////////////////////////////////////////////////////////////
// MXN x N = M
NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) {
+ if (X->dataType() != A->dataType())
+ throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType());
+
+ if (Y != nullptr && X->dataType() != Y->dataType())
+ throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType());
int xLenDim, yLenDim(0);
@@ -279,7 +291,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy);
}
else {
- BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES);
+ //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}
if(pA != A)
@@ -291,6 +304,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
////////////////////////////////////////////////////////////////////////////
// (X * Y) = Z[0]
NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) {
+ if (X->dataType() != Y->dataType())
+ throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType());
+
+ if (Z != nullptr && X->dataType() != Z->dataType())
+ throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType());
int xLenDim(0), yLenDim(0);
@@ -316,13 +334,14 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
const auto yType = Y->dataType();
const auto zType = Z->dataType();
- BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES);
+ //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
return Z;
}
-BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}
diff --git a/libnd4j/include/helpers/cpu/cublasHelper.cpp b/libnd4j/include/helpers/cpu/cublasHelper.cpp
index cc2a4029a..3dba2d31e 100644
--- a/libnd4j/include/helpers/cpu/cublasHelper.cpp
+++ b/libnd4j/include/helpers/cpu/cublasHelper.cpp
@@ -21,13 +21,41 @@
#include "../cublasHelper.h"
namespace nd4j {
- namespace cublas {
- void* handle() {
- return nullptr;
- }
-
- void destroyHandle(void* handle) {
- //
- }
+ static void* handle_() {
+ return nullptr;
}
+
+ static void destroyHandle_(void* handle) {
+
+ }
+
+ CublasHelper::CublasHelper() {
+
+ }
+
+ CublasHelper::~CublasHelper() {
+
+ }
+
+ CublasHelper* CublasHelper::getInstance() {
+ if (!_INSTANCE)
+ _INSTANCE = new nd4j::CublasHelper();
+
+ return _INSTANCE;
+ }
+
+ void* CublasHelper::handle() {
+ return nullptr;
+ }
+
+ void* CublasHelper::solver() {
+ return nullptr;
+ }
+
+ void* CublasHelper::handle(int deviceId) {
+ return nullptr;
+ }
+
+
+    nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = nullptr;
}
\ No newline at end of file
diff --git a/libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp
rename to libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp
diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp
rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp
diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp
rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp
diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp
rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp
diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp
rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops.hpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops.hpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp
diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp
similarity index 100%
rename from libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp
rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp
diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h
index bff16b2d4..d4f92881e 100644
--- a/libnd4j/include/helpers/cublasHelper.h
+++ b/libnd4j/include/helpers/cublasHelper.h
@@ -21,12 +21,28 @@
#ifndef DEV_TESTS_CUBLASHELPER_H
#define DEV_TESTS_CUBLASHELPER_H
-namespace nd4j {
- namespace cublas {
- void* handle();
+#include <pointercast.h>
+#include <dll.h>
+#include <vector>
- void destroyHandle(void* handle);
- }
+namespace nd4j {
+ class CublasHelper {
+ private:
+ static CublasHelper *_INSTANCE;
+
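+        // one cuBLAS handle and one cuSolver handle per device, indexed by deviceId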
+        std::vector<void*> _cache;
+        std::vector<void*> _solvers;
+
+ CublasHelper();
+ ~CublasHelper();
+ public:
+ static CublasHelper* getInstance();
+
+ void* solver();
+
+ void* handle();
+ void* handle(int deviceId);
+ };
}
#endif //DEV_TESTS_CUBLASHELPER_H
diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu
index d0579b66d..0c7f2cbc1 100644
--- a/libnd4j/include/helpers/cuda/ConstantHelper.cu
+++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu
@@ -26,6 +26,7 @@
#include
#include
#include
+#include <execution/AffinityManager.h>
#define CONSTANT_LIMIT 49152
@@ -43,23 +44,11 @@ namespace nd4j {
}
int ConstantHelper::getCurrentDevice() {
- int dev = 0;
- auto res = cudaGetDevice(&dev);
-
- if (res != 0)
- throw cuda_exception::build("cudaGetDevice failed", res);
-
- return dev;
+ return AffinityManager::currentDeviceId();
}
int ConstantHelper::getNumberOfDevices() {
- int dev = 0;
- auto res = cudaGetDeviceCount(&dev);
-
- if (res != 0)
- throw cuda_exception::build("cudaGetDeviceCount failed", res);
-
- return dev;
+ return AffinityManager::numberOfDevices();
}
diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu
index 56e726004..ac5eb4176 100644
--- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu
+++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu
@@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
blocksPerGrid.y = math::nd4j_ceil(static_cast<float>(M) / threadsPerBlock.y); // rows
}
- BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
- // BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
+ //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
@@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
threadsPerBlock.x = 512;
blocksPerGrid.x = math::nd4j_ceil(static_cast<float>(M) / threadsPerBlock.x); // rows
}
- BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
- // BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
+ //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
@@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
NDArray::prepareSpecialUse({Z}, {X, Y});
- BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
- // BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
+ //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
return Z;
}
-BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
}
\ No newline at end of file
diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu
index f80bf1f87..6f2cf2084 100644
--- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu
+++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu
@@ -20,12 +20,15 @@
#include <cublas_v2.h>
+#include <cusolverDn.h>
#include "../cublasHelper.h"
#include <exceptions/cuda_exception.h>
#include <logger.h>
+#include <execution/AffinityManager.h>
namespace nd4j {
- void* cublas::handle() {
+
+ static void* handle_() {
auto _handle = new cublasHandle_t();
auto status = cublasCreate_v2(_handle); // initialize CUBLAS context
if (status != CUBLAS_STATUS_SUCCESS)
@@ -34,7 +37,16 @@ namespace nd4j {
return reinterpret_cast<void *>(_handle);
}
- void cublas::destroyHandle(void* handle) {
+ static void* solver_() {
+ auto cusolverH = new cusolverDnHandle_t();
+ auto status = cusolverDnCreate(cusolverH);
+ if (status != CUSOLVER_STATUS_SUCCESS)
+ throw cuda_exception::build("cuSolver handle creation failed !", status);
+
+ return cusolverH;
+ }
+
+ static void destroyHandle_(void* handle) {
auto ch = reinterpret_cast<cublasHandle_t *>(handle);
auto status = cublasDestroy_v2(*ch);
if (status != CUBLAS_STATUS_SUCCESS)
@@ -42,4 +54,57 @@ namespace nd4j {
delete ch;
}
+
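+    // eagerly create a cuBLAS and a cuSolver handle for every visible device,
+    // so later lookups are a plain vector access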
+ CublasHelper::CublasHelper() {
+ auto numDevices = AffinityManager::numberOfDevices();
+ auto currentDevice = AffinityManager::currentDeviceId();
+ _cache.resize(numDevices);
+ _solvers.resize(numDevices);
+ for (int e = 0; e < numDevices; e++) {
+ AffinityManager::setCurrentDevice(e);
+
+ _cache[e] = handle_();
+ _solvers[e] = solver_();
+ }
+
+        // restore the original device
+ AffinityManager::setCurrentDevice(currentDevice);
+ }
+
+ CublasHelper::~CublasHelper() {
+ auto numDevices = AffinityManager::numberOfDevices();
+
+ for (int e = 0; e < numDevices; e++)
+ destroyHandle_(_cache[e]);
+ }
+
+ CublasHelper* CublasHelper::getInstance() {
+ if (!_INSTANCE)
+ _INSTANCE = new nd4j::CublasHelper();
+
+ return _INSTANCE;
+ }
+
+ void* CublasHelper::handle() {
+ auto deviceId = AffinityManager::currentDeviceId();
+ return handle(deviceId);
+ }
+
+ void* CublasHelper::solver() {
+ auto deviceId = AffinityManager::currentDeviceId();
+        if (deviceId < 0 || deviceId >= _solvers.size())
+ throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);
+
+ return _solvers[deviceId];
+ }
+
+ void* CublasHelper::handle(int deviceId) {
+        if (deviceId < 0 || deviceId >= _cache.size())
+ throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);
+
+ return _cache[deviceId];
+ }
+
+
+    nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = nullptr;
}
\ No newline at end of file
diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp
index aeeedc007..889e48181 100644
--- a/libnd4j/include/loops/cpu/random.cpp
+++ b/libnd4j/include/loops/cpu/random.cpp
@@ -276,23 +276,6 @@ namespace functions {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
}
- // FIXME: eventually we might want to get rid of that
-#ifndef __CLION_IDE__
-/*
-    BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
-
-    BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
-
-    BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
-    BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
-*/
-#endif
-
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}
diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp
index e673f4eae..dc8a3eeb1 100644
--- a/libnd4j/include/loops/cuda/broadcasting.chpp
+++ b/libnd4j/include/loops/cuda/broadcasting.chpp
@@ -60,9 +60,18 @@ static __global__ void broadcastInverseSimple(
functions::broadcast::Broadcast<X, Y, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
+
namespace functions {
namespace broadcast {
+ static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
+ return shape::getIndexOffset(index, shapeInfo, length);
+ }
+
+ static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
+ return shape::length(shapeInfo);
+ }
+
        template <typename X, typename Y, typename Z>
        template <typename OpClass>
__host__ void Broadcast