diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java index cf7d11838..e84488f9d 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform.client; + import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.exceptions.UnirestException; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 1e525fc3b..2cc01e288 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -51,11 +51,13 @@ datavec-spark-inference-model ${datavec.version} + org.datavec datavec-spark_2.11 ${project.version} + org.datavec datavec-data-image @@ -67,61 +69,73 @@ akka-cluster_2.11 ${akka.version} + joda-time joda-time ${jodatime.version} + org.apache.commons commons-lang3 ${commons-lang3.version} + org.hibernate hibernate-validator ${hibernate.version} + org.scala-lang scala-library ${scala.version} + org.scala-lang scala-reflect ${scala.version} + org.yaml snakeyaml ${snakeyaml.version} + com.fasterxml.jackson.core jackson-core ${jackson.version} + com.fasterxml.jackson.core jackson-databind ${jackson.version} + com.fasterxml.jackson.core jackson-annotations ${jackson.version} + com.fasterxml.jackson.datatype jackson-datatype-jdk8 ${jackson.version} + com.fasterxml.jackson.datatype jackson-datatype-jsr310 ${jackson.version} + com.typesafe.play play-java_2.11 @@ -137,39 +151,44 @@ + net.jodah typetools ${jodah.typetools.version} + com.typesafe.play play-json_2.11 ${play.version} + com.typesafe.play play-server_2.11 ${play.version} + com.typesafe.play play_2.11 ${play.version} + com.typesafe.play play-netty-server_2.11 ${play.version} - com.mashape.unirest unirest-java ${unirest.version} test + com.beust jcommander diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java index d3d4aee30..e524f00a7 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java @@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest { public static void before() throws Exception { server = new CSVSparkTransformServer(); FileUtils.write(fileSave, transformProcess.toJson()); + // Only one time Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = @@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest { } } }); + server.runMain(new String[] {"-dp", "9050"}); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java index 78021d3e5..a9e24d78e 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; @@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest { server = new CSVSparkTransformServer(); FileUtils.write(fileSave, transformProcess.toJson()); // Only one time + Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = new org.nd4j.shade.jackson.databind.ObjectMapper(); @@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest { } } }); + server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java index 62dec47c1..bfae23358 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java index a0561563b..e08d881ef 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java @@ -16,6 +16,7 @@ package org.datavec.spark.transform; + import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java index eaf967325..4c14f4c3d 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java @@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client; import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.Unirest; import com.mashape.unirest.request.HttpRequest; -import com.mashape.unirest.request.HttpRequestWithBody; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; +import lombok.val; import org.deeplearning4j.nearestneighbor.model.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; @@ -51,6 +51,7 @@ public class NearestNeighborsClient { static { // Only one time + Unirest.setObjectMapper(new ObjectMapper() { private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = new org.nd4j.shade.jackson.databind.ObjectMapper(); @@ -89,7 +90,7 @@ public class NearestNeighborsClient { NearestNeighborRequest request = new NearestNeighborRequest(); request.setInputIndex(index); request.setK(k); - HttpRequestWithBody req = Unirest.post(url + "/knn"); + val req = Unirest.post(url + "/knn"); req.header("accept", "application/json") .header("Content-Type", "application/json").body(request); addAuthHeader(req); @@ -112,7 +113,7 @@ public class NearestNeighborsClient { Base64NDArrayBody base64NDArrayBody = Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); - HttpRequestWithBody req = Unirest.post(url + "/knnnew"); + val req = Unirest.post(url + "/knnnew"); req.header("accept", "application/json") .header("Content-Type", "application/json").body(base64NDArrayBody); addAuthHeader(req); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java index 9fcd76981..5313edd22 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java @@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.gson.JsonParser; -import jdk.nashorn.internal.objects.annotations.Property; import lombok.Getter; import lombok.NonNull; import org.apache.commons.compress.compressors.gzip.GzipUtils; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java index b1e53e90a..74ffd3fe8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/ArgmaxAdapter.java @@ -17,7 +17,7 @@ package org.deeplearning4j.nn.adapters; import lombok.val; -import org.deeplearning4j.nn.api.OutputAdapter; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java index b67767050..f0c83669b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/adapters/Regression2dAdapter.java @@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.deeplearning4j.nn.api.OutputAdapter; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java index ae7e4b5c1..6b99a92d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.api; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.api.ndarray.INDArray; /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index cf3fec70c..99f8aeff0 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -24,6 +24,7 @@ import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.exception.DL4JException; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 719a727f2..731ca398b 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -25,6 +25,7 @@ import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; +import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.dataset.AsyncDataSetIterator;; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.eval.RegressionEvaluation; diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml new file mode 100644 index 000000000..e2d5f901f --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml @@ -0,0 +1,110 @@ + + + + 4.0.0 + jar + + + org.deeplearning4j + deeplearning4j-remote + 1.0.0-SNAPSHOT + + + deeplearning4j-json-server + 1.0.0-SNAPSHOT + deeplearning4j-json-server + + + + junit + junit + ${junit.version} + test + + + + org.projectlombok + lombok + ${lombok.version} + provided + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.nd4j + nd4j-json-client + ${project.version} + + + + org.nd4j + nd4j-json-server + ${project.version} + + + + org.deeplearning4j + deeplearning4j-parallel-wrapper + ${project.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + + + + test-nd4j-native + + true + + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + test-nd4j-cuda-10.1 + + false + + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java new file mode 100644 index 000000000..796a51ad0 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java @@ -0,0 +1,205 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.parallelism.ParallelInference; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.remote.serving.SameDiffServlet; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; + +/** + * + * @author astoyakin + */ +@Slf4j +@NoArgsConstructor +public class DL4jServlet extends SameDiffServlet { + + protected ParallelInference parallelInference; + protected Model model; + protected boolean parallelEnabled = true; + + public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter inferenceAdapter, + @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer) { + super(inferenceAdapter, serializer, deserializer); + this.parallelInference = parallelInference; + this.model = null; + this.parallelEnabled = true; + } + + public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter inferenceAdapter, + @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer) { + super(inferenceAdapter, serializer, deserializer); + this.model = model; + this.parallelInference = null; + this.parallelEnabled = false; + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(SERVING_ENDPOINT)) { + val contentType = request.getContentType(); + if (validateRequest(request,response)) { + val stream = request.getInputStream(); + val bufferedReader = new BufferedReader(new InputStreamReader(stream)); + char[] charBuffer = new char[128]; + int bytesRead = -1; + val buffer = new StringBuilder(); + while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { + buffer.append(charBuffer, 0, bytesRead); + } + val requestString = buffer.toString(); + + val mds = inferenceAdapter.apply(deserializer.deserialize(requestString)); + + O result = null; + if (parallelEnabled) { + // process result + result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays())); + } + else { + synchronized(this) { + if (model instanceof ComputationGraph) + result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays())); + else if (model instanceof MultiLayerNetwork) { + Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1), + "Input data for MultilayerNetwork is invalid!"); + result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false, + mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null)); + } + } + } + processorReturned = serializer.serialize(result); + } + } else { + // we return error otherwise + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + /** + * Creates servlet to serve models + * + * @param type of Input class + * @param type of Output class + * + * @author raver119@gmail.com + * @author astoyakin + */ + public static class Builder { + + private ParallelInference pi; + private Model model; + + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + private int port; + private boolean parallelEnabled = true; + + public Builder(@NonNull ParallelInference pi) { + this.pi = pi; + } + + public Builder(@NonNull Model model) { + this.model = model; + } + + public Builder inferenceAdapter(@NonNull InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method is required to specify serializer + * + * @param serializer + * @return + */ + public Builder serializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method allows to specify deserializer + * + * @param deserializer + * @return + */ + public Builder deserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method allows to specify port + * + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method activates parallel inference + * + * @param parallelEnabled + * @return + */ + public Builder parallelEnabled(boolean parallelEnabled) { + this.parallelEnabled = parallelEnabled; + return this; + } + + public DL4jServlet build() { + return parallelEnabled ? new DL4jServlet(pi, inferenceAdapter, serializer, deserializer) : + new DL4jServlet(model, inferenceAdapter, serializer, deserializer); + } + } +} + + + + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java new file mode 100644 index 000000000..030231f6e --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java @@ -0,0 +1,392 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.NonNull; +import lombok.val; +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.api.ModelAdapter; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.parallelism.ParallelInference; +import org.deeplearning4j.parallelism.inference.InferenceMode; +import org.deeplearning4j.parallelism.inference.LoadBalanceMode; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.adapters.InputAdapter; +import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.remote.SameDiffJsonModelServer; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + + +import java.util.List; + +/** + * This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models + * + * Server url will be http://0.0.0.0:{port}>/v1/serving + * Server only accepts POST requests + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + * @author astoyakin + */ +public class JsonModelServer extends SameDiffJsonModelServer { + + // all serving goes through ParallelInference + protected ParallelInference parallelInference; + + + protected ModelAdapter modelAdapter; + + // actual models + protected ComputationGraph cgModel; + protected MultiLayerNetwork mlnModel; + + // service stuff + protected InferenceMode inferenceMode; + protected int numWorkers; + + protected boolean enabledParallel = true; + + protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) { + super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } + + protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { + super(inferenceAdapter, serializer, deserializer, port); + + this.cgModel = cgModel; + this.inferenceMode = inferenceMode; + this.numWorkers = numWorkers; + } + + protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) { + super(inferenceAdapter, serializer, deserializer, port); + + this.mlnModel = mlnModel; + this.inferenceMode = inferenceMode; + this.numWorkers = numWorkers; + } + + protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter inferenceAdapter, JsonSerializer serializer, JsonDeserializer deserializer, int port) { + super(inferenceAdapter, serializer, deserializer, port); + + this.parallelInference = pi; + } + + /** + * This method stops server + * + * @throws Exception + */ + @Override + public void stop() throws Exception { + if (parallelInference != null) + parallelInference.shutdown(); + super.stop(); + } + + /** + * This method starts server + * @throws Exception + */ + @Override + public void start() throws Exception { + // if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case + if (sdModel != null) { + super.start(); + return; + } + Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined"); + + val model = cgModel != null ? (Model) cgModel : (Model) mlnModel; + // PI construction is optional, since we can have it defined + if (enabledParallel) { + if (parallelInference == null) { + Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead"); + + parallelInference = new ParallelInference.Builder(model) + .inferenceMode(inferenceMode) + .workers(numWorkers) + .loadBalanceMode(LoadBalanceMode.FIFO) + .batchLimit(16) + .queueLimit(128) + .build(); + } + servingServlet = new DL4jServlet.Builder(parallelInference) + .parallelEnabled(true) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .build(); + } + else { + servingServlet = new DL4jServlet.Builder(model) + .parallelEnabled(false) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .build(); + } + start(port, servingServlet); + } + + /** + * Creates servlet to serve different types of models + * + * @param type of Input class + * @param type of Output class + * + * @author raver119@gmail.com + * @author astoyakin + */ + public static class Builder { + + private SameDiff sdModel; + private ComputationGraph cgModel; + private MultiLayerNetwork mlnModel; + private ParallelInference pi; + + private String[] orderedInputNodes; + private String[] orderedOutputNodes; + + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + + private InputAdapter inputAdapter; + private OutputAdapter outputAdapter; + + private int port; + + private boolean parallelMode = true; + + // these fields actually require defaults + private InferenceMode inferenceMode = InferenceMode.BATCHED; + private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices(); + + public Builder(@NonNull SameDiff sdModel) { + this.sdModel = sdModel; + } + + public Builder(@NonNull MultiLayerNetwork mlnModel) { + this.mlnModel = mlnModel; + } + + public Builder(@NonNull ComputationGraph cgModel) { + this.cgModel = cgModel; + } + + public Builder(@NonNull ParallelInference pi) { + this.pi = pi; + } + + /** + * This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type + * @param inferenceAdapter + * @return + */ + public Builder inferenceAdapter(@NonNull InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method allows you to specify InputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require OutputAdapter defined + * @param inputAdapter + * @return + */ + public Builder inputAdapter(@NonNull InputAdapter inputAdapter) { + this.inputAdapter = inputAdapter; + return this; + } + + /** + * This method allows you to specify OutputtAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require InputAdapter defined + * @param outputAdapter + * @return + */ + public Builder outputAdapter(@NonNull OutputAdapter outputAdapter) { + this.outputAdapter = outputAdapter; + return this; + } + + /** + * This method allows you to specify serializer + * + * @param serializer + * @return + */ + public Builder outputSerializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method allows you to specify deserializer + * + * @param deserializer + * @return + */ + public Builder inputDeserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details + * + * @param inferenceMode + * @return + */ + public Builder inferenceMode(@NonNull InferenceMode inferenceMode) { + this.inferenceMode = inferenceMode; + return this; + } + + /** + * This method allows you to specify number of worker threads for ParallelInference + * + * @param numWorkers + * @return + */ + public Builder numWorkers(int numWorkers) { + this.numWorkers = numWorkers; + return this; + } + + /** + * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedInputNodes(String... args) { + orderedInputNodes = args; + return this; + } + + /** + * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedInputNodes(@NonNull List args) { + orderedInputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows you to specify output nodes + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedOutputNodes(String... args) { + Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args; + return this; + } + + /** + * This method allows you to specify output nodes + * + * PLEASE NOTE: this argument only used for SameDiff models + * @param args + * @return + */ + public Builder orderedOutputNodes(@NonNull List args) { + Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows you to specify http port + * + * PLEASE NOTE: port must be free and be in range regular TCP/IP ports range + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method switches on ParallelInference usage + * @param - true - to use ParallelInference, false - to use ComputationGraph or + * MultiLayerNetwork directly + * + * PLEASE NOTE: this doesn't apply to SameDiff models + * + * @throws Exception + */ + public Builder parallelMode(boolean enable) { + this.parallelMode = enable; + return this; + } + + public JsonModelServer build() { + if (inferenceAdapter == null) { + if (inputAdapter != null && outputAdapter != null) { + inferenceAdapter = new InferenceAdapter() { + @Override + public MultiDataSet apply(I input) { + return inputAdapter.apply(input); + } + + @Override + public O apply(INDArray... outputs) { + return outputAdapter.apply(outputs); + } + }; + } else + throw new IllegalArgumentException("Either InferenceAdapter or InputAdapter + OutputAdapter should be configured"); + } + + if (sdModel != null) { + Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined"); + return new JsonModelServer(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } else if (cgModel != null) + return new JsonModelServer(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); + else if (mlnModel != null) + return new JsonModelServer(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers); + else if (pi != null) + return new JsonModelServer(pi, inferenceAdapter, serializer, deserializer, port); + else + throw new IllegalStateException("No models were defined for JsonModelServer"); + } + } + +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java new file mode 100644 index 000000000..aa353f307 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -0,0 +1,749 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.graph.MergeVertex; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.parallelism.inference.InferenceMode; +import org.deeplearning4j.remote.helpers.House; +import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter; +import org.deeplearning4j.remote.helpers.PredictedPrice; +import org.junit.After; +import org.junit.Test; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.remote.clients.JsonRemoteInference; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.databind.ObjectMapper; + + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE; +import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; +import static org.junit.Assert.*; + +@Slf4j +public class JsonModelServerTest { + private static final MultiLayerNetwork model; + private final int PORT = 18080; + + static { + val conf = new NeuralNetConfiguration.Builder() + .seed(119) + .updater(new Adam(0.119f)) + .weightInit(WeightInit.XAVIER) + .list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build()) + .build(); + + model = new MultiLayerNetwork(conf); + model.init(); + } + + @After + public void pause() throws Exception { + // TODO: the same port was used in previous test and not accessible immediately. Might be better solution. + TimeUnit.SECONDS.sleep(2); + } + + + @Test + public void testStartStopParallel() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val serverDL = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + val serverSD = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT+1) + .build(); + try { + serverDL.start(); + serverSD.start(); + + val clientDL = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + PredictedPrice price = clientDL.predict(house); + long timeStart = System.currentTimeMillis(); + price = clientDL.predict(house); + long timeStop = System.currentTimeMillis(); + log.info("Time spent: {} ms", timeStop - timeStart); + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + val clientSD = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving") + .build(); + + PredictedPrice price2 = clientSD.predict(house); + timeStart = System.currentTimeMillis(); + price = clientSD.predict(house); + timeStop = System.currentTimeMillis(); + log.info("Time spent: {} ms", timeStop - timeStart); + assertNotNull(price); + assertEquals((float) 3.0, price.getPrice(), 1e-5); + + } + finally { + serverSD.stop(); + serverDL.stop(); + } + } + + @Test + public void testStartStopSequential() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val serverDL = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + val serverSD = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT+1) + .build(); + + serverDL.start(); + serverDL.stop(); + + serverSD.start(); + serverSD.stop(); + } + + @Test + public void basicServingTestForSD() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 1,4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new JsonModelServer.Builder(sd) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) district + 1.0f, price.getPrice(), 1e-5); + } + finally { + server.stop(); + } + } + + @Test + public void basicServingTestForDLSynchronized() throws Exception { + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(INPLACE) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build(); + House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house1); + + val timeStart = System.currentTimeMillis(); + PredictedPrice price1 = client.predict(house1); + PredictedPrice price2 = client.predict(house2); + PredictedPrice price3 = client.predict(house3); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + } finally { + server.stop(); + } + } + + @Test + public void basicServingTestForDL() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .numWorkers(1) + .inferenceMode(SEQUENTIAL) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .parallelMode(false) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) 0.421444, price.getPrice(), 1e-5); + + } finally { + server.stop(); + } + } + + @Test + public void testDeserialization_1() { + String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}"; + val deserializer = new House.HouseDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(2, result.getDistrict()); + assertEquals(100, result.getArea()); + assertEquals(2, result.getBathrooms()); + assertEquals(3, result.getBedrooms()); + + } + + @Test + public void testDeserialization_2() { + String request = "{\"price\":1}"; + val deserializer = new PredictedPrice.PredictedPriceDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(1.0, result.getPrice(), 1e-4); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_1() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(null) + .port(18080) + .build(); + } + + @Test //(expected = NullPointerException.class) + public void negativeServingTest_2() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .port(PORT) + .build(); + + } + + @Test(expected = IOException.class) + public void negativeServingTest_3() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(SEQUENTIAL) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + } finally { + server.stop(); + } + } + + @Test + public void asyncServingTest() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(SEQUENTIAL) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) 0.421444, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } + finally { + server.stop(); + } + } + + @Test + public void negativeAsyncTest() throws Exception { + + val server = new JsonModelServer.Builder(model) + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .inferenceMode(InferenceMode.BATCHED) + .numWorkers(1) + .port(PORT) + .build(); + + try { + server.start(); + + // Fake deserializer to test failure + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:" + PORT + "/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + try { + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } catch (ExecutionException e) { + assertTrue(e.getMessage().contains("Deserialization failed")); + } + } finally { + server.stop(); + } + } + + + @Test + public void testSameDiffMnist() throws Exception { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10)); + SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); + + val server = new JsonModelServer.Builder(sd) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT+1) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try{ + server.start(); + for( int i=0; i<10; i++ ){ + INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28); + INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax"); + float[] fArr = f.toFloatVector(); + int out = client.predict(fArr); + assertEquals(exp.argMax().getInt(0), out); + } + } finally { + server.stop(); + } + } + + @Test + public void testMlnMnist() throws Exception { + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new DenseLayer.Builder().nIn(784).nOut(10).build()) + .layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + for (int i = 0; i < 10; i++) { + INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28); + INDArray exp = net.output(f); + float[] fArr = f.toFloatVector(); + int out = client.predict(fArr); + assertEquals(exp.argMax().getInt(0), out); + } + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + @Test + public void testCompGraph() throws Exception { + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .graphBuilder() + .addInputs("input1", "input2") + .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1") + .addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2") + .addVertex("merge", new MergeVertex(), "L1", "L2") + .addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge") + .setOutputs("out") + .build(); + + ComputationGraph net = new ComputationGraph(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("in") + .orderedOutputNodes("softmax") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .parallelMode(false) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + //client.predict(new float[]{0.0f, 1.0f, 2.0f}); + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + @Test + public void testCompGraph_1() throws Exception { + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .updater(new Sgd(0.01)) + .graphBuilder() + .addInputs("input") + .addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input") + .addLayer("out1", new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nIn(4).nOut(3).build(), "L1") + .addLayer("out2", new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.MSE) + .nIn(4).nOut(2).build(), "L1") + .setOutputs("out1","out2") + .build(); + + final ComputationGraph net = new ComputationGraph(conf); + net.init(); + + val server = new JsonModelServer.Builder(net) + .outputSerializer( new IntSerde()) + .inputDeserializer(new FloatSerde()) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(float[] input) { + return new MultiDataSet(Nd4j.create(input, 1, input.length), null); + } + + @Override + public Integer apply(INDArray... nnOutput) { + return nnOutput[0].argMax().getInt(0); + } + }) + .orderedInputNodes("input") + .orderedOutputNodes("out") + .port(PORT + 1) + .inferenceMode(SEQUENTIAL) + .numWorkers(2) + .parallelMode(false) + .build(); + + val client = JsonRemoteInference.builder() + .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving") + .outputDeserializer(new IntSerde()) + .inputSerializer( new FloatSerde()) + .build(); + + try { + server.start(); + val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + assertNotNull(result); + } catch (Exception e){ + e.printStackTrace(); + throw e; + } finally { + server.stop(); + } + } + + private static class FloatSerde implements JsonSerializer, JsonDeserializer{ + private final ObjectMapper om = new ObjectMapper(); + + @Override + public float[] deserialize(String json) { + try { + return om.readValue(json, FloatHolder.class).getFloats(); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + @Override + public String serialize(float[] o) { + try{ + return om.writeValueAsString(new FloatHolder(o)); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + //Use float holder so Jackson does ser/de properly (no "{}" otherwise) + @AllArgsConstructor @NoArgsConstructor @Data + private static class FloatHolder { + private float[] floats; + } + } + + private static class IntSerde implements JsonSerializer, JsonDeserializer { + private final ObjectMapper om = new ObjectMapper(); + + @Override + public Integer deserialize(String json) { + try { + return om.readValue(json, Integer.class); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + + @Override + public String serialize(Integer o) { + try{ + return om.writeValueAsString(o); + } catch (IOException e){ + throw new RuntimeException(e); + } + } + } +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java new file mode 100644 index 000000000..1b347d112 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote; + +import lombok.val; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class ServletTest { + + private JsonModelServer server; + + @Before + public void setUp() throws Exception { + val sd = SameDiff.create(); + server = new JsonModelServer.Builder(sd) + .port(8080) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(String input) { + return null; + } + + @Override + public String apply(INDArray... nnOutput) { + return null; + } + }) + .outputSerializer(new JsonSerializer() { + @Override + public String serialize(String o) { + return ""; + } + }) + .inputDeserializer(new JsonDeserializer() { + @Override + public String deserialize(String json) { + return ""; + } + }) + .orderedInputNodes("input") + .orderedOutputNodes("output") + .build(); + + server.start(); + //server.join(); + } + + @After + public void tearDown() throws Exception { + server.stop(); + } + + @Test + public void getEndpoints() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "application/json"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypeGet() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "text/plain"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypePost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "text/plain"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void postForServing() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(500, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundPost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving/some"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundGet() throws Exception { + val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" ); + requestGet.setHeader("Content-type", "application/json"); + + val responseGet = HttpClientBuilder.create().build().execute( requestGet ); + assertEquals(404, responseGet.getStatusLine().getStatusCode()); + } + +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java new file mode 100644 index 000000000..d66c8bae5 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class House { + private int district; + private int bedrooms; + private int bathrooms; + private int area; + + + public static class HouseSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull House o) { + return new Gson().toJson(o); + } + } + + public static class HouseDeserializer implements JsonDeserializer { + @Override + public House deserialize(@NonNull String json) { + return new Gson().fromJson(json, House.class); + } + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java new file mode 100644 index 000000000..82976a3da --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; + +@Slf4j +public class HouseToPredictedPriceAdapter implements InferenceAdapter { + + @Override + public MultiDataSet apply(@NonNull House input) { + // we just create vector array with shape[4] and assign it's value to the district value + return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null); + } + + @Override + public PredictedPrice apply(INDArray... nnOutput) { + return new PredictedPrice(nnOutput[0].getFloat(0)); + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java new file mode 100644 index 000000000..c4024bb1d --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class PredictedPrice { + private float price; + + public static class PredictedPriceSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull PredictedPrice o) { + return new Gson().toJson(o); + } + } + + public static class PredictedPriceDeserializer implements JsonDeserializer { + @Override + public PredictedPrice deserialize(@NonNull String json) { + return new Gson().fromJson(json, PredictedPrice.class); + } + } +} diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml new file mode 100644 index 000000000..59b35644e --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml @@ -0,0 +1,48 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml new file mode 100644 index 000000000..1fc937e10 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/pom.xml @@ -0,0 +1,30 @@ + + + + 4.0.0 + pom + + + deeplearning4j-json-server + + + + org.deeplearning4j + deeplearning4j + 1.0.0-SNAPSHOT + + + deeplearning4j-remote + 1.0.0-SNAPSHOT + deeplearning4j-remote + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 48966a57d..149a122f2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -21,7 +21,6 @@ import lombok.*; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.ModelAdapter; -import org.deeplearning4j.nn.api.OutputAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 022b28076..0e2fd339a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -22,7 +22,6 @@ import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.ModelAdapter; -import org.deeplearning4j.nn.api.OutputAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index e13b62cda..b8b70befb 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -144,6 +144,7 @@ dl4j-perf dl4j-integration-tests deeplearning4j-common + deeplearning4j-remote diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 2e3c51091..8e940bedb 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -163,9 +163,9 @@ if(CUDA_BLAS) if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) endif() else() list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) @@ -173,24 +173,24 @@ if(CUDA_BLAS) elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0 if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() else() if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 ) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 ) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() endif() @@ -205,34 +205,34 @@ if(CUDA_BLAS) message("CUDA 10 Debug build") if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8 if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() else() if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() endif() endif() @@ -249,7 +249,7 @@ if(CUDA_BLAS) file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h) file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h) file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h) - file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) + file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) if (NOT BUILD_TESTS) diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index a0529d106..726549415 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -344,7 +344,7 @@ bool NDArray::isS() const { ////////////////////////////////////////////////////////////////////////// bool NDArray::isR() const { auto xType = ArrayOptions::dataType(this->_shapeInfo); - return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8; + return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 87555a303..9ce90176f 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1769,6 +1769,17 @@ ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); +typedef nd4j::LaunchContext OpaqueLaunchContext; + +ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext(); +ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); + } #endif //NATIVEOPERATIONS_NATIVEOPS_H diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 74bd072c8..f5d4996e4 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2985,6 +2985,38 @@ const char* runFullBenchmarkSuit(bool printOut) { return chars; } +nd4j::LaunchContext* defaultLaunchContext() { + return LaunchContext::defaultContext(); +} + +Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { + return nullptr; +} + BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index 67173c971..126837ad9 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -356,7 +356,7 @@ void NDArray::tile(const std::vector& reps, NDArray& target) const { auto stream = getContext()->getCudaStream(); prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); registerSpecialUse({&target}, {this}); } @@ -375,7 +375,7 @@ void NDArray::tile(NDArray& target) const { auto stream = getContext()->getCudaStream(); prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); registerSpecialUse({&target}, {this}); } @@ -434,7 +434,7 @@ void NDArray::repeat(int dimension, NDArray& target) const { NDArray::prepareSpecialUse({&target}, {this}); auto stream = getContext()->getCudaStream(); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES); NDArray::registerSpecialUse({&target}, {this}); } diff --git a/libnd4j/blas/cuda/NDArrayLambda.hpp b/libnd4j/blas/cuda/NDArrayLambda.hpp index f7846e121..bf9848981 100644 --- a/libnd4j/blas/cuda/NDArrayLambda.hpp +++ b/libnd4j/blas/cuda/NDArrayLambda.hpp @@ -23,6 +23,14 @@ #include #include +static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); +} + +static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); +} + template static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); @@ -86,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -95,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL z[e * zEws] = lambda(x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(x[xOffset]); } @@ -115,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -124,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz z[e * zEws] = lambda(e, x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(e, x[xOffset]); } @@ -147,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -156,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(e, x[xOffset], y[yOffset]); } @@ -180,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -189,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(x[xOffset], y[yOffset]); } @@ -216,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -225,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto wOffset = shape::getIndexOffset(e, wShapeInfo, zLength); - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto wOffset = __getIndexOffset(e, wShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); } diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 0d441fe5e..af9fc6776 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -1691,11 +1692,7 @@ void setOmpMinThreads(int threads) { } int getDevice() { - int curDevice = -1; - - cudaGetDevice(&curDevice); - - return curDevice; + return nd4j::AffinityManager::currentDeviceId(); } void setElementThreshold(int num) { @@ -2391,8 +2388,8 @@ void sortByValue(Nd4jPointer *extraPointers, auto xLength = shape::length(xShapeInfo); auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); // check if xLength is a power of 2, and use bitonic sort, if that's the case @@ -2406,7 +2403,7 @@ void sortByValue(Nd4jPointer *extraPointers, for (int k = 2; k <= xLength; k = 2*k) { for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); } } } else { @@ -2430,7 +2427,7 @@ void sortByValue(Nd4jPointer *extraPointers, int rev = 0; do{ int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); n>>=1; rev = 1; } while(n > 1); @@ -3342,6 +3339,7 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) { nd4j::graph::Context* createGraphContext(int nodeId) { return new nd4j::graph::Context(nodeId); } + nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { return &ptr->randomGenerator(); } @@ -3460,3 +3458,35 @@ const char* runFullBenchmarkSuit(bool printOut) { Nd4jLong getCachedMemory(int deviceId) { return nd4j::ConstantHelper::getInstance()->getCachedAmount(deviceId); } + +nd4j::LaunchContext* defaultLaunchContext() { + return LaunchContext::defaultContext(); +} + +Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { + return lc->getScalarPointer(); +} + +Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { + return lc->getReductionPointer(); +} + +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { + return lc->getAllocationPointer(); +} + +Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { + return lc->getCudaStream(); +} + +Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { + return lc->getCudaSpecialStream(); +} + +Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { + return lc->getCublasHandle(); +} + +Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { + return lc->getCusolverHandle(); +} \ No newline at end of file diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h new file mode 100644 index 000000000..463d6942e --- /dev/null +++ b/libnd4j/include/execution/AffinityManager.h @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_AFFINITYMANAGER_H +#define LIBND4J_AFFINITYMANAGER_H + +#include +#include +#include +#include + +namespace nd4j { + class ND4J_EXPORT AffinityManager { + private: + static std::atomic _lastDevice; + static int _numberOfDevices; + static std::mutex _currentMutex; + static std::mutex _numberMutex; + + public: + static int currentNativeDeviceId(); + static int currentDeviceId(); + static int numberOfDevices(); + static void setCurrentDevice(int deviceId); + static void setCurrentNativeDevice(int deviceId); + }; +} + +#endif //DEV_TESTS_AFFINITYMANAGER_H diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h new file mode 100644 index 000000000..77b8d4ca3 --- /dev/null +++ b/libnd4j/include/execution/ContextBuffers.h @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_CONTEXTBUFFERS_H +#define LIBND4J_CONTEXTBUFFERS_H + +#include +#include + +namespace nd4j { + class ND4J_EXPORT ContextBuffers { + private: + void* _reductionPointer; + void* _scalarPointer; + void* _allocationPointer; + bool _allocated = true; + + int _deviceId = -1; + + void initialize(); + public: + ContextBuffers(); + ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false); + ~ContextBuffers(); + + void* reductionBuffer(); + void* scalarBuffer(); + void* allocationBuffer(); + + void setReductionBuffer(void* pointer); + void setScalarBuffer(void* pointer); + void setAllocationBuffer(void* pointer); + + void triggerOwnership(bool isOwner); + + int deviceId(); + }; +} + + +#endif //DEV_TESTS_CONTEXTBUFFERS_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 853a970d2..02b772415 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -35,6 +35,8 @@ #include #include #include +#include +#include @@ -44,49 +46,44 @@ class ND4J_EXPORT LaunchContext { private: static std::vector> _contexts; + static std::mutex _mutex; #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - void* _reductionPointer; - void* _scalarPointer; - int* _allocationPointer; - cudaStream_t *_cudaStream = nullptr; - cudaStream_t *_cudaSpecialStream = nullptr; - void *_cublasHandle = nullptr; + cudaStream_t* _cudaStream = nullptr; + cudaStream_t* _cudaSpecialStream = nullptr; + void* _cublasHandle = nullptr; + void* _cusolverHandle = nullptr; #endif // JCPP bool _isAllocated = false; #endif // CUDA - nd4j::memory::Workspace* _workspace = nullptr; - int _deviceID = 0; + nd4j::memory::Workspace* _workspace = nullptr; + int _deviceID = 0; + public: #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr); - FORCEINLINE void* getReductionPointer () const {return _reductionPointer;}; + void* getReductionPointer () const; + void* getScalarPointer() const; + int* getAllocationPointer() const; + void* getCublasHandle() const; + void* getCusolverHandle() const; + cudaStream_t* getCudaStream() const; + cudaStream_t* getCudaSpecialStream() const; - FORCEINLINE void* getScalarPointer() const {return _scalarPointer;}; - - FORCEINLINE int* getAllocationPointer() const {return _allocationPointer;}; - - FORCEINLINE void* getCublasHandle() const {return _cublasHandle;}; - FORCEINLINE cudaStream_t* getCudaStream() const {return _cudaStream;}; - FORCEINLINE cudaStream_t* getCudaSpecialStream() const {return _cudaSpecialStream;}; - - FORCEINLINE void setReductionPointer (void* reductionPointer) {_reductionPointer = reductionPointer;}; - - FORCEINLINE void setScalarPointer(void* scalarPointer) {_scalarPointer = scalarPointer;}; - - FORCEINLINE void setAllocationPointer(int* allocationPointer) {_allocationPointer = allocationPointer;}; - - FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;}; - FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;}; - FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; }; + void setReductionPointer (void* reductionPointer); + void setScalarPointer(void* scalarPointer); + void setAllocationPointer(int* allocationPointer); + void setCudaStream(cudaStream_t* cudaStream); + void setCudaSpecialStream(cudaStream_t* cudaStream); + void setCublasHandle(void *handle); #endif // JCPP diff --git a/libnd4j/include/execution/cpu/AffinityManager.cpp b/libnd4j/include/execution/cpu/AffinityManager.cpp new file mode 100644 index 000000000..7927982a6 --- /dev/null +++ b/libnd4j/include/execution/cpu/AffinityManager.cpp @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + int AffinityManager::currentDeviceId() { + return 0; + } + + int AffinityManager::currentNativeDeviceId() { + return 0; + } + + int AffinityManager::numberOfDevices() { + return 1; + } + + void AffinityManager::setCurrentDevice(int deviceId) { + // no-op + } + + void AffinityManager::setCurrentNativeDevice(int deviceId) { + // no-op + } +} \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp new file mode 100644 index 000000000..d385548d0 --- /dev/null +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#include +#include + +namespace nd4j { + ContextBuffers::ContextBuffers() { + _deviceId = AffinityManager::currentDeviceId(); + } + + ContextBuffers::~ContextBuffers() { + // no-op + } + + ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; + } + + void ContextBuffers::initialize() { + // no-op + } + + void* ContextBuffers::reductionBuffer() { + return _reductionPointer; + } + + void* ContextBuffers::scalarBuffer() { + return _scalarPointer; + } + + void* ContextBuffers::allocationBuffer() { + return _allocationPointer; + } + + void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; + } + + void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; + } + + void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; + } + + void ContextBuffers::triggerOwnership(bool isOwner) { + _allocated = isOwner; + } + + int ContextBuffers::deviceId() { + return _deviceId; + } +} diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp new file mode 100644 index 000000000..47207719f --- /dev/null +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.11.17. +// + +#include +#include +#include +#include + +thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); + +namespace nd4j { + + LaunchContext::~LaunchContext() { + + } + + std::vector> LaunchContext::_contexts = std::vector>(); + +//////////////////////////////////////////////////////////////////////// + LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; + } + + LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { + + } + + LaunchContext* LaunchContext::defaultContext() { + // TODO: we need it to be device-aware, but only once we add NUMA support for cpu + if (LaunchContext::_contexts.empty()) { + LaunchContext::_contexts.emplace_back(std::make_shared()); + } + + // return context for current device + return LaunchContext::_contexts[0].get(); + } +} \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu new file mode 100644 index 000000000..811dc267a --- /dev/null +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -0,0 +1,108 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +thread_local int globalThreadToDevice = -1; + +namespace nd4j { + std::mutex AffinityManager::_currentMutex; + std::mutex AffinityManager::_numberMutex; + int AffinityManager::_numberOfDevices = -1; + + int AffinityManager::currentDeviceId() { + // if there's no affinity set - set it now + if (globalThreadToDevice < 0) { + + // this block must be thread-local + _currentMutex.lock(); + + globalThreadToDevice = _lastDevice++; + + // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise + if (globalThreadToDevice >= numberOfDevices()) { + globalThreadToDevice = 0; + _lastDevice = numberOfDevices() > 1 ? 1 : 0; + } + + _currentMutex.unlock(); + + setCurrentDevice(globalThreadToDevice); + } + + // if we already know affinity - just return it + if (globalThreadToDevice >= 0) + return globalThreadToDevice; + + int dev = 0; + auto res = cudaGetDevice(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDevice failed", res); + + return dev; + } + + int AffinityManager::currentNativeDeviceId() { + int dev = 0; + auto res = cudaGetDevice(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDevice failed", res); + + return dev; + } + + int AffinityManager::numberOfDevices() { + _numberMutex.lock(); + // we want to cache number of devices + if (_numberOfDevices <= 0) { + int dev = 0; + auto res = cudaGetDeviceCount(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDeviceCount failed", res); + + _numberOfDevices = dev; + } + _numberMutex.unlock(); + + return _numberOfDevices; + } + + void AffinityManager::setCurrentNativeDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + } + + void AffinityManager::setCurrentDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("cudaSetDevice failed", res); + + // update thread-device affinity + globalThreadToDevice = deviceId; + + // TODO: update context buffers? + } + + std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); +} \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu new file mode 100644 index 000000000..f82747b91 --- /dev/null +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +#include +#include +#include +#include + +namespace nd4j { + ContextBuffers::ContextBuffers() { + nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId()); + _deviceId = AffinityManager::currentDeviceId(); + } + + ContextBuffers::~ContextBuffers() { + if (_allocated) { + nd4j_printf("Releasing ContextBuffers\n",""); + + if (_allocationPointer != nullptr) + cudaFree(_allocationPointer); + + if (_scalarPointer != nullptr) + cudaFree(_scalarPointer); + + if (_allocationPointer != nullptr) + cudaFree(_reductionPointer); + } + } + + ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; + } + + void ContextBuffers::initialize() { + nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId()); + + auto res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); + if (res != 0) + throw std::runtime_error("_reductionPointer allocation failed"); + + res = cudaMalloc(reinterpret_cast(&_scalarPointer), 16); + if (res != 0) + throw std::runtime_error("_scalarPointer allocation failed"); + + res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); + if (res != 0) + throw std::runtime_error("_allocationPointer allocation failed"); + + _allocated = true; + } + + void* ContextBuffers::reductionBuffer() { + if (_reductionPointer == nullptr) + initialize(); + + return _reductionPointer; + } + + void* ContextBuffers::scalarBuffer() { + if (_scalarPointer == nullptr) + initialize(); + + return _scalarPointer; + } + + void* ContextBuffers::allocationBuffer() { + if (_allocationPointer == nullptr) + initialize(); + + return _allocationPointer; + } + + void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; + } + + void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; + } + + void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; + } + + void ContextBuffers::triggerOwnership(bool isOwner) { + _allocated = isOwner; + } + + int ContextBuffers::deviceId() { + return _deviceId; + } +} diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu new file mode 100644 index 000000000..004ed2cac --- /dev/null +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -0,0 +1,182 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.11.17. +// + +#include +#include +#include +#include +#include +#include + +thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); + +namespace nd4j { + + std::vector> LaunchContext::_contexts = std::vector>(); + std::mutex LaunchContext::_mutex; + +//////////////////////////////////////////////////////////////////////// +LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { + + _cudaStream = cudaStream; + _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = allocationPointer; + _workspace = nullptr; + _isAllocated = false; +} + +LaunchContext::~LaunchContext() { + if (_isAllocated) { + cudaStreamSynchronize(*_cudaStream); + cudaStreamSynchronize(*_cudaSpecialStream); + + cudaStreamDestroy(*_cudaStream); + cudaStreamDestroy(*_cudaSpecialStream); + + delete _cudaStream; + delete _cudaSpecialStream; + } +} + +//////////////////////////////////////////////////////////////////////// +LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; + + _isAllocated = true; + _cudaStream = new cudaStream_t(); + _cudaSpecialStream = new cudaStream_t(); + if (nullptr == _cudaStream || nullptr == _cudaSpecialStream) + throw std::runtime_error("Failed to allocate memory for new CUDA stream"); + + cudaError_t err = cudaStreamCreate(_cudaStream); + if (err != 0) + throw cuda_exception::build("Failed to create default CUDA stream with launch context", err); + + err = cudaStreamCreate(_cudaSpecialStream); + if (err != 0) + throw cuda_exception::build("Failed to create special CUDA stream with launch context", err); + + _cublasHandle = CublasHelper::getInstance()->handle(); + + _cusolverHandle = CublasHelper::getInstance()->solver(); + + auto res = cudaStreamSynchronize(*_cudaStream); + if (res != 0) + throw cuda_exception::build("Initial sync failed", res); +} + + LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { + _isAllocated = false; + _cudaStream = reinterpret_cast(cudaStream); + _cudaSpecialStream = reinterpret_cast(cudaStream); + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = reinterpret_cast(allocationPointer); + } + + LaunchContext* LaunchContext::defaultContext() { + /** + * This method returns LaunchContext, that has multiple entities within: + * 1) temporary buffers. they must be per-thread + * 2) CUDA stream. it must be either per-thread or per-device + * 3) cuBLAS handle. it must be per-device + */ + auto deviceId = AffinityManager::currentDeviceId(); + + // we need this block synchronous, to avoid double initialization etc + _mutex.lock(); + if (LaunchContext::_contexts.empty()) { + // create one context per device + auto numDevices = AffinityManager::numberOfDevices(); + + _contexts.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentDevice(e); + + LaunchContext::_contexts[e] = std::make_shared(); + } + + // don't forget to restore device back again + AffinityManager::setCurrentDevice(deviceId); + } + _mutex.unlock(); + + // return context for current device + return LaunchContext::_contexts[deviceId].get(); + } + + + void* LaunchContext::getReductionPointer () const { + return contextBuffers.reductionBuffer(); + }; + + void* LaunchContext::getScalarPointer() const { + return contextBuffers.scalarBuffer(); + }; + + int* LaunchContext::getAllocationPointer() const { + return reinterpret_cast(contextBuffers.allocationBuffer()); + }; + + void* LaunchContext::getCublasHandle() const { + return _cublasHandle; + }; + + void* LaunchContext::getCusolverHandle() const { + return _cusolverHandle; + }; + + cudaStream_t* LaunchContext::getCudaStream() const { + return _cudaStream; + }; + + cudaStream_t* LaunchContext::getCudaSpecialStream() const { + return _cudaSpecialStream; + }; + + + void LaunchContext::setReductionPointer (void* reductionPointer) { + contextBuffers.setReductionBuffer(reductionPointer); + }; + + void LaunchContext::setScalarPointer(void* scalarPointer) { + contextBuffers.setScalarBuffer(scalarPointer); + }; + + void LaunchContext::setAllocationPointer(int* allocationPointer) { + contextBuffers.setAllocationBuffer(allocationPointer); + }; + + void LaunchContext::setCudaStream(cudaStream_t* cudaStream) { + _cudaStream = cudaStream; + }; + + void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) { + _cudaSpecialStream = cudaStream; + }; + + void LaunchContext::setCublasHandle(void *handle) { + _cublasHandle = handle; + }; +} \ No newline at end of file diff --git a/libnd4j/include/execution/impl/LaunchContext.cpp b/libnd4j/include/execution/impl/LaunchContext.cpp deleted file mode 100644 index edc95dabc..000000000 --- a/libnd4j/include/execution/impl/LaunchContext.cpp +++ /dev/null @@ -1,130 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 30.11.17. -// - -#include -#include -#include -#include - -namespace nd4j { - -#ifdef __CUDABLAS__ - -//////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { - - _cudaStream = cudaStream; - _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; - _reductionPointer = reductionPointer; - _scalarPointer = scalarPointer; - _allocationPointer = allocationPointer; - _workspace = nullptr; - _isAllocated = false; -} -#endif - -LaunchContext::~LaunchContext() { -#ifdef __CUDABLAS__ - if (_isAllocated) { - cudaStreamSynchronize(*_cudaStream); - cudaStreamSynchronize(*_cudaSpecialStream); - - cudaStreamDestroy(*_cudaStream); - cudaStreamDestroy(*_cudaSpecialStream); - - delete _cudaStream; - delete _cudaSpecialStream; - - cudaFree(_reductionPointer); - cudaFree(_allocationPointer); - cudaFree(_scalarPointer); - - cublas::destroyHandle(_cublasHandle); - } -#endif -} - - std::vector> LaunchContext::_contexts = std::vector>(); - -//////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; - -#ifdef __CUDABLAS__ - _isAllocated = true; - _cudaStream = new cudaStream_t(); - _cudaSpecialStream = new cudaStream_t(); - if (nullptr == _cudaStream || nullptr == _cudaSpecialStream) - throw std::runtime_error("Failed to allocate memory for new CUDA stream"); - - cudaError_t err = cudaStreamCreate(_cudaStream); - if (err != 0) - throw cuda_exception::build("Failed to create default CUDA stream with launch context", err); - - err = cudaStreamCreate(_cudaSpecialStream); - if (err != 0) - throw cuda_exception::build("Failed to create special CUDA stream with launch context", err); - - _cublasHandle = cublas::handle(); - - auto res = cudaStreamSynchronize(*_cudaStream); - if (res != 0) - throw cuda_exception::build("Initial sync failed", res); - - res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); - if (res != 0) - throw std::runtime_error("_reductionPointer allocation failed"); - - res = cudaMalloc(reinterpret_cast(&_scalarPointer), 8); - if (res != 0) - throw std::runtime_error("_scalarPointer allocation failed"); - - res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); - if (res != 0) - throw std::runtime_error("_allocationPointer allocation failed"); -#else - // -#endif -} - - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { -#ifdef __CUDABLAS__ - _isAllocated = false; - _cudaStream = reinterpret_cast(cudaStream); - _cudaSpecialStream = reinterpret_cast(cudaStream); - _reductionPointer = reductionPointer; - _scalarPointer = scalarPointer; - _allocationPointer = reinterpret_cast(allocationPointer); -#else - // no-op -#endif - } - -LaunchContext* LaunchContext::defaultContext() { - // TODO: we need it to be device-aware - if (LaunchContext::_contexts.empty()) { - LaunchContext::_contexts.emplace_back(std::make_shared()); - } - return LaunchContext::_contexts[0].get(); -} - -} \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index f74bd5637..43a4f97c1 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -21,6 +21,7 @@ #ifndef __CUDABLAS__ #include +#include #include #include #include @@ -59,11 +60,11 @@ namespace nd4j { } int ConstantHelper::getCurrentDevice() { - return 0L; + return AffinityManager::currentDeviceId(); } int ConstantHelper::getNumberOfDevices() { - return 1; + return AffinityManager::numberOfDevices(); } ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 293360a25..d17d2c021 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -21,6 +21,7 @@ #include "../MmulHelper.h" #include #include +#include namespace nd4j { @@ -147,7 +148,12 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + if (A->dataType() != B->dataType()) + throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType()); + + if (C != nullptr && A->dataType() != C->dataType()) + throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType()); if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); @@ -212,7 +218,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (double) beta, reinterpret_cast(pC->getBuffer()), ldc); } else { - BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } if(pC != C) { @@ -230,6 +237,11 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con //////////////////////////////////////////////////////////////////////////// // MXN x N = M NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + if (X->dataType() != A->dataType()) + throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); + + if (Y != nullptr && X->dataType() != Y->dataType()) + throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType()); int xLenDim, yLenDim(0); @@ -279,7 +291,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); } else { - BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } if(pA != A) @@ -291,6 +304,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { + if (X->dataType() != Y->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType()); + + if (Z != nullptr && X->dataType() != Z->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType()); int xLenDim(0), yLenDim(0); @@ -316,13 +334,14 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c const auto yType = Y->dataType(); const auto zType = Z->dataType(); - BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); return Z; } -BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } diff --git a/libnd4j/include/helpers/cpu/cublasHelper.cpp b/libnd4j/include/helpers/cpu/cublasHelper.cpp index cc2a4029a..3dba2d31e 100644 --- a/libnd4j/include/helpers/cpu/cublasHelper.cpp +++ b/libnd4j/include/helpers/cpu/cublasHelper.cpp @@ -21,13 +21,41 @@ #include "../cublasHelper.h" namespace nd4j { - namespace cublas { - void* handle() { - return nullptr; - } - - void destroyHandle(void* handle) { - // - } + static void* handle_() { + return nullptr; } + + static void destroyHandle_(void* handle) { + + } + + CublasHelper::CublasHelper() { + + } + + CublasHelper::~CublasHelper() { + + } + + CublasHelper* CublasHelper::getInstance() { + if (!_INSTANCE) + _INSTANCE = new nd4j::CublasHelper(); + + return _INSTANCE; + } + + void* CublasHelper::handle() { + return nullptr; + } + + void* CublasHelper::solver() { + return nullptr; + } + + void* CublasHelper::handle(int deviceId) { + return nullptr; + } + + + nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0; } \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp rename to libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops.hpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops.hpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index bff16b2d4..d4f92881e 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -21,12 +21,28 @@ #ifndef DEV_TESTS_CUBLASHELPER_H #define DEV_TESTS_CUBLASHELPER_H -namespace nd4j { - namespace cublas { - void* handle(); +#include +#include +#include - void destroyHandle(void* handle); - } +namespace nd4j { + class CublasHelper { + private: + static CublasHelper *_INSTANCE; + + std::vector _cache; + std::vector _solvers; + + CublasHelper(); + ~CublasHelper(); + public: + static CublasHelper* getInstance(); + + void* solver(); + + void* handle(); + void* handle(int deviceId); + }; } #endif //DEV_TESTS_CUBLASHELPER_H diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index d0579b66d..0c7f2cbc1 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -26,6 +26,7 @@ #include #include #include +#include #define CONSTANT_LIMIT 49152 @@ -43,23 +44,11 @@ namespace nd4j { } int ConstantHelper::getCurrentDevice() { - int dev = 0; - auto res = cudaGetDevice(&dev); - - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); - - return dev; + return AffinityManager::currentDeviceId(); } int ConstantHelper::getNumberOfDevices() { - int dev = 0; - auto res = cudaGetDeviceCount(&dev); - - if (res != 0) - throw cuda_exception::build("cudaGetDeviceCount failed", res); - - return dev; + return AffinityManager::numberOfDevices(); } diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 56e726004..ac5eb4176 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou blocksPerGrid.y = math::nd4j_ceil(static_cast(M) / threadsPerBlock.y); // rows } - BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) } if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); @@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* threadsPerBlock.x = 512; blocksPerGrid.x = math::nd4j_ceil(static_cast(M) / threadsPerBlock.x); // rows } - BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) } if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); @@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c NDArray::prepareSpecialUse({Z}, {X, Y}); - BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) auto cudaResult = cudaStreamSynchronize(*stream); if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); @@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c return Z; } -BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index f80bf1f87..6f2cf2084 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -20,12 +20,15 @@ #include +#include #include "../cublasHelper.h" #include #include +#include namespace nd4j { - void* cublas::handle() { + + static void* handle_() { auto _handle = new cublasHandle_t(); auto status = cublasCreate_v2(_handle); // initialize CUBLAS context if (status != CUBLAS_STATUS_SUCCESS) @@ -34,7 +37,16 @@ namespace nd4j { return reinterpret_cast(_handle); } - void cublas::destroyHandle(void* handle) { + static void* solver_() { + auto cusolverH = new cusolverDnHandle_t(); + auto status = cusolverDnCreate(cusolverH); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("cuSolver handle creation failed !", status); + + return cusolverH; + } + + static void destroyHandle_(void* handle) { auto ch = reinterpret_cast(handle); auto status = cublasDestroy_v2(*ch); if (status != CUBLAS_STATUS_SUCCESS) @@ -42,4 +54,57 @@ namespace nd4j { delete ch; } + + CublasHelper::CublasHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + auto currentDevice = AffinityManager::currentDeviceId(); + _cache.resize(numDevices); + _solvers.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentDevice(e); + + _cache[e] = handle_(); + _solvers[e] = solver_(); + } + + // don't forget to restore back original device + AffinityManager::setCurrentDevice(currentDevice); + } + + CublasHelper::~CublasHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) + destroyHandle_(_cache[e]); + } + + CublasHelper* CublasHelper::getInstance() { + if (!_INSTANCE) + _INSTANCE = new nd4j::CublasHelper(); + + return _INSTANCE; + } + + void* CublasHelper::handle() { + auto deviceId = AffinityManager::currentDeviceId(); + return handle(deviceId); + } + + void* CublasHelper::solver() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _solvers.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); + + return _solvers[deviceId]; + } + + void* CublasHelper::handle(int deviceId) { + if (deviceId < 0 || deviceId > _cache.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); + + return _cache[deviceId]; + } + + + nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0; } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index aeeedc007..889e48181 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -276,23 +276,6 @@ namespace functions { DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) } - // FIXME: eventually we might want to get rid of that -#ifndef __CLION_IDE__ -/* - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) - - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) - - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) -*/ -#endif - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); } diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp index e673f4eae..dc8a3eeb1 100644 --- a/libnd4j/include/loops/cuda/broadcasting.chpp +++ b/libnd4j/include/loops/cuda/broadcasting.chpp @@ -60,9 +60,18 @@ static __global__ void broadcastInverseSimple( functions::broadcast::Broadcast::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } + namespace functions { namespace broadcast { + static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); + } + + static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); + } + template template __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -120,9 +129,9 @@ namespace functions { if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); + tadLength = _length(tadOnlyShapeInfo); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; + numTads = _length(yShapeInfo) / tadLength; xEWS = shape::elementWiseStride(xShapeInfo); zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); } @@ -146,9 +155,9 @@ namespace functions { else { // it is expected that x and z tads and y array all have the same length for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + auto xOffset = _getIndexOffset(i, xShapeInfo, tadLength); + auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); } } @@ -186,9 +195,9 @@ namespace functions { if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); + tadLength = _length(tadOnlyShapeInfo); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; + numTads = _length(xShapeInfo) / tadLength; yEWS = shape::elementWiseStride(yShapeInfo); zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); } @@ -212,14 +221,15 @@ namespace functions { // it is expected that x and z tads and y array all have the same length for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto yOffset = _getIndexOffset(i, yShapeInfo, tadLength); + auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); } } } } + /* BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); diff --git a/libnd4j/include/loops/cuda/broadcasting.cu b/libnd4j/include/loops/cuda/broadcasting.cu new file mode 100644 index 000000000..8028db2ba --- /dev/null +++ b/libnd4j/include/loops/cuda/broadcasting.cu @@ -0,0 +1,115 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace functions { + namespace broadcast { + template + void Broadcast::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + // + } + + template + void Broadcast::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + template + void Broadcast::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + // + } + + + template + template + void Broadcast::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index 6acf71356..6cc3f3cbb 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -224,6 +224,77 @@ namespace functions { } } + + template + void BroadcastBool::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + void BroadcastBool::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastBool::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastBool::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 7c17538fa..94793f8e8 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -361,6 +361,32 @@ namespace functions { } } + + + + template + Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + return 0; + } + + template + void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + + } + + template + template + Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { + return 0; + } + + template + template + _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/pairwise.cu b/libnd4j/include/loops/cuda/pairwise.cu new file mode 100644 index 000000000..17f8537e5 --- /dev/null +++ b/libnd4j/include/loops/cuda/pairwise.cu @@ -0,0 +1,79 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../pairwise_transform.h" + +namespace functions { + namespace pairwise_transforms { + template + void PairWiseTransform::exec( + const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *extraParams) { + + } + + template + void PairWiseTransform::exec( + const int opNum, + void *x, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *z, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong len) { + + } + + + template + template + void PairWiseTransform:: exec( + void *vx, + Nd4jLong* xShapeInfo, + void *vy, + Nd4jLong* yShapeInfo, + void *vresult, + Nd4jLong* zShapeInfo, + void *vextraParams) { + + } + + template + template + void PairWiseTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong len) { + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_bool.cu b/libnd4j/include/loops/cuda/pairwise_bool.cu index 41bca38cb..0834386f2 100644 --- a/libnd4j/include/loops/cuda/pairwise_bool.cu +++ b/libnd4j/include/loops/cuda/pairwise_bool.cu @@ -110,6 +110,63 @@ void PairWiseBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_ DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS); } + + + template + void PairWiseBoolTransform::exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams) { + + } + + template + void PairWiseBoolTransform::exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n) { + + } + + + template + template + void PairWiseBoolTransform::exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams) { + + } + + template + template + void PairWiseBoolTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n) { + + } + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index 4cc1c6565..727f0868f 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -442,6 +442,39 @@ namespace functions { DEBUG_KERNEL(stream, opNum); } + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce3.cu b/libnd4j/include/loops/cuda/reduce3.cu new file mode 100644 index 000000000..1ad94beee --- /dev/null +++ b/libnd4j/include/loops/cuda/reduce3.cu @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include +#include +#include +#include +#include + +namespace functions { + namespace reduce3 { + template + template + void Reduce3::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) { + + } + + + template + void Reduce3::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) { + + } + + + template + template + void Reduce3::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { + + } + + + template + template + void Reduce3::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + + template + template + void Reduce3::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { + + } + + + template + void Reduce3::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { + + } + + + template + void Reduce3::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + + template + void Reduce3::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { + + } + + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar.cu b/libnd4j/include/loops/cuda/scalar.cu new file mode 100644 index 000000000..67cbc7a98 --- /dev/null +++ b/libnd4j/include/loops/cuda/scalar.cu @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "loops/scalar.h" +#include +#include +#include +#include +#include + +namespace functions { + namespace scalar { + + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_bool.cu b/libnd4j/include/loops/cuda/scalar_bool.cu index a5a26d7e7..c6563c9ef 100644 --- a/libnd4j/include/loops/cuda/scalar_bool.cu +++ b/libnd4j/include/loops/cuda/scalar_bool.cu @@ -231,6 +231,41 @@ void ScalarBoolTransform::executeCudaAlongDimension(dim3& launchDims, cudaS } BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarBoolTransform::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarBoolTransform::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + template + void ScalarBoolTransform::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } } } diff --git a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu index 7584949cc..8ee950c25 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu @@ -21,84 +21,6 @@ #include -////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - auto y = static_cast(vy); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, yShapeInfo, xLength); - int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength); - - Y v0 = y[posIJ]; - Y v1 = y[posIT]; - - if(!descending == (v0 > v1)) { - y[posIJ] = v1; - y[posIT] = v0; - - X xtemp = x[posIJ]; - x[posIJ] = x[posIT]; - x[posIT] = xtemp; - } - } - } - } -} - ////////////////////////////////////////////////////////////////////////// template __global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { @@ -264,11 +186,5 @@ __host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *str bitonicArbitraryStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); } -template -__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { - bitonicArbitraryStepKernelValue<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); -} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu index 3e1a0edc5..d9b2ec74c 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu @@ -21,60 +21,6 @@ #include -////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, yShapeInfo, xLength); - int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (y[posI]>y[posIXJ])) { - /* exchange(i,ixj); */ - X temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - - Y ytemp = y[posI]; - y[posI] = y[posIXJ]; - y[posIXJ] = ytemp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (y[posI] @@ -189,13 +135,6 @@ __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, bitonicSortStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); } -////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernelValue<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); -} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/repeatKernel.cu b/libnd4j/include/loops/cuda/specials/repeatKernel.cu index 5193aca2a..c3177049f 100644 --- a/libnd4j/include/loops/cuda/specials/repeatKernel.cu +++ b/libnd4j/include/loops/cuda/specials/repeatKernel.cu @@ -62,9 +62,9 @@ namespace nd4j { } } } - BUILD_DOUBLE_TEMPLATE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer, + BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, - Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES, LIBND4J_TYPES); + Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES); template void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength, @@ -88,10 +88,10 @@ namespace nd4j { dim3 launchDims(256, 512, 8192); repeatKernelDouble<<>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets); } - BUILD_DOUBLE_TEMPLATE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, + BUILD_SINGLE_TEMPLATE_TWICE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets, - cudaStream_t stream), LIBND4J_TYPES, LIBND4J_TYPES); + cudaStream_t stream), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tileKernel.cu b/libnd4j/include/loops/cuda/specials/tileKernel.cu index d2c62ced7..7d2e87e2d 100644 --- a/libnd4j/include/loops/cuda/specials/tileKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tileKernel.cu @@ -21,6 +21,17 @@ #include namespace nd4j { + static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); + } + + static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) { + return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); + } + + static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); + } //////////////////////////////////////////////////////////////////////// template @@ -34,31 +45,20 @@ namespace nd4j { //const auto resultLength = shape::length(outputShape); if (shape::order(outputShape) == 'c') { // ews == 1 always here for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); *(reinterpret_cast(outputBuffer) + i) = *(reinterpret_cast(inputBuffer) + yOffset); } -// for(Nd4jLong i=0; itemplate templatedAssign, (newBuff, i, this->_buffer, yOffset), LIBND4J_TYPES); -// -// } } else { -// - //auto inputLength = shape::lenght(inputShape); for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = shape::getIndexOffset(i, outputShape, resultLength); - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + - yOffset); -// BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, xOffset, this->_buffer, yOffset), LIBND4J_TYPES); + auto xOffset = _getIndexOffset(i, outputShape, resultLength); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + yOffset); } } } - BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel, - (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), - LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), LIBND4J_TYPES); template void tileKernelH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, cudaStream_t *stream) { @@ -77,29 +77,26 @@ namespace nd4j { if (ordering == 'c' && ews == 1) { // ews == 1 always here for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + - yOffset)); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } else if (ordering == 'c' && ews > 1) { for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*( - reinterpret_cast(inputBuffer) + yOffset)); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } else { for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = shape::getIndexOffset(i, outputShape, resultLength); - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*( - reinterpret_cast(inputBuffer) + yOffset)); + auto xOffset = _getIndexOffset(i, outputShape, resultLength); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } } - BUILD_DOUBLE_TEMPLATE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES); template void tileKernelHH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) { @@ -107,5 +104,5 @@ namespace nd4j { tileKernelDouble<<>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); } - BUILD_DOUBLE_TEMPLATE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index cb3d06a4b..1e2f3ce4f 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -413,6 +413,74 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa DEBUG_KERNEL(stream, opNum); } + + template + Y SummaryStatsReduce::execScalar(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams) { + return 0; + } + + template + void SummaryStatsReduce::execScalar(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer) { + + } + + template + void SummaryStatsReduce::exec(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength) { + + } + + template + template + Y SummaryStatsReduce::execScalar(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams) { + return 0; + } + + template + template + void SummaryStatsReduce::execScalar(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer) { + // + } + + + template + template + void SummaryStatsReduce::exec(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer, + int *dimension, + int dimensionLength) { + + } + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_any.cu b/libnd4j/include/loops/cuda/transform/transform_any.cu index a217167a6..34f56380a 100644 --- a/libnd4j/include/loops/cuda/transform/transform_any.cu +++ b/libnd4j/include/loops/cuda/transform/transform_any.cu @@ -114,6 +114,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); } + template + void TransformAny::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { + + } + + template + template + void TransformAny::exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index bff361fcb..a01221cfa 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -120,6 +120,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); } + template + void TransformBool::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformBool::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index 05d4c9999..e1cd36256 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -142,6 +142,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); } + template + void TransformFloat::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformFloat::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); } diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index 1e9bf2d64..a0d137d64 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -118,6 +118,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); } + template + void TransformSame::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformSame::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 8a5b65c04..10385812d 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -119,6 +119,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); } + template + void TransformStrict::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformStrict::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cpu/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp similarity index 96% rename from libnd4j/include/loops/cpu/type_conversions.cpp rename to libnd4j/include/loops/impl/type_conversions.cpp index 3c923de39..dc85b9554 100644 --- a/libnd4j/include/loops/cpu/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -209,15 +209,6 @@ PRAGMA_OMP_ATOMIC_ARGS(write) } }; - _CUDA_H Nd4jLong TypeCast::estimateQuantizedSize(Nd4jLong rawSize) { - if (rawSize <= 0) - throw std::runtime_error("Input size for quantization can't be <= 0"); - - // 2 fp32 values for max/min, and rawSize number of BYTES - return 8 + rawSize; - } - - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); diff --git a/libnd4j/include/loops/type_conversions.h b/libnd4j/include/loops/type_conversions.h index 1c54f41d4..d6029d7af 100644 --- a/libnd4j/include/loops/type_conversions.h +++ b/libnd4j/include/loops/type_conversions.h @@ -69,7 +69,14 @@ namespace nd4j { template static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize); + FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { + if (rawSize <= 0) + throw std::runtime_error("Input size for quantization can't be <= 0"); + + // 2 fp32 values for max/min, and rawSize number of BYTES + return 8 + rawSize; + } + template static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp index 781dea86a..5ae075c99 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp @@ -85,7 +85,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp index d10c32435..9dea93699 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp @@ -86,7 +86,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index 887225f6a..af282fe7c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -107,7 +107,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index f27b4fc61..76f2d6830 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -114,7 +114,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp index 6806be664..3309c6104 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp @@ -112,7 +112,8 @@ namespace nd4j { COPY_SHAPE(input, epsShape); COPY_SHAPE(bias, gradShape); - return SHAPELIST(epsShape, gradShape); + return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); + } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 7541ab841..1e0330294 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -75,7 +75,7 @@ namespace nd4j { DECLARE_TYPES(non_max_suppression) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); + ->setAllowedOutputTypes({ALL_INDICES}); } } diff --git a/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp index 5ddd1654e..ddd18dc84 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp @@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) { Nd4jLong *dLdbcShapeInfo = nullptr; COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); - return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo); + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 1d2b25678..8047da41a 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -101,7 +101,7 @@ namespace ops { Nd4jLong *out; COPY_SHAPE(in, out); - return SHAPELIST(out); + return SHAPELIST(CONSTANT(out)); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp index fa9ab7b40..5484d822d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp @@ -87,8 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES); } /* diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp index d01a8e2be..9a5141a82 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp @@ -89,7 +89,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES); } /* diff --git a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp index b29a79504..b4a54ad7a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp @@ -119,11 +119,9 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void col2im_, (nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW), LIBND4J_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 6d319d993..033e0b5e5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -2445,71 +2445,52 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); } void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); } void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } - - - BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - - BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp index c9a4e0fb5..c75bbf131 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp @@ -81,10 +81,8 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const } } -BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES); - void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp index 49626168c..97cd2f84e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -76,7 +76,7 @@ namespace nd4j { double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp index 131165117..002c68226 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp @@ -122,11 +122,9 @@ static void im2col_(nd4j::LaunchContext & context, const NDArray& input, NDArra void im2col(nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void im2col_, (nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal), LIBND4J_TYPES); - } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 062db8d87..2ac679fc5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -334,10 +334,6 @@ namespace helpers { BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); } - - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, - (NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops), - NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 52a41a201..45024b5cb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -32,7 +32,6 @@ namespace helpers { theFirst->applyPairwiseLambda(theSecond, functor, nullptr); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); @@ -46,7 +45,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); @@ -61,8 +59,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -76,8 +72,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -91,8 +85,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -106,8 +98,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -121,8 +111,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void cubeDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -137,8 +125,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reduceNorm1_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -153,8 +139,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropy_, (NDArray* logits, NDArray* labels, NDArray* output);, FLOAT_TYPES); - void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); } @@ -173,8 +157,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropyGrad_, (NDArray* logits, NDArray* labels, NDArray*output);, FLOAT_TYPES); - void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); } @@ -190,8 +172,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -207,8 +187,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -222,8 +200,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -237,8 +213,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -256,8 +230,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softSignDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -272,8 +244,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softPlusDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -291,8 +261,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -306,8 +274,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardSigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -347,13 +313,10 @@ namespace helpers { void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* axis, NDArray*output);, FLOAT_TYPES); void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* subtrah, NDArray* axis, NDArray*output);, FLOAT_TYPES); - ////////////////////////////////////////////////////////////////////////// template @@ -393,7 +356,6 @@ static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArr void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void weightedCrossEntropyWithLogitsFunctor_, (NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp index 75b23c932..a02d5918c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp @@ -410,10 +410,9 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c gradI *= gradO; } -BUILD_DOUBLE_TEMPLATE(template void lrnBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta), LIBND4J_TYPES, FLOAT_TYPES); void lrnBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 29d9f463b..1fb1ef1df 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -345,8 +345,6 @@ template int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int inverse_, (NDArray* input, NDArray* output), FLOAT_TYPES); template int logdetFunctor_(NDArray* input, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp deleted file mode 100644 index 6990f2dc3..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.12.17. -// - -#include - -namespace nd4j { - namespace ops { - namespace helpers { - template - void __matmul(NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE) transA; - CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE) transB; - - int M = vA->sizeAt(0); - int N = vB->sizeAt(1); - int K = vA->sizeAt(1); - - int ldA = transA == CblasNoTrans ? M : K; - int ldB = transB == CblasNoTrans ? K : N; - int ldC = M; - - auto A = reinterpret_cast(vA->buffer()); - auto B = reinterpret_cast(vB->buffer()); - auto C = reinterpret_cast(vC->buffer()); - - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - Z c_mnp = 0; - - for (int k = 0; k < K; ++k) - c_mnp += (Z) A[tA == CblasNoTrans ? (m + k * ldA) : (m * ldA + k)] * (Z) B[tB == CblasNoTrans ? (k + n * ldB) : (k * ldB + n)]; - - C[m + n * ldC] = (Z) alpha * (Z) c_mnp + (Z) beta * (Z) C[m + n * ldC]; - } - } - } - - - void _matmul(nd4j::LaunchContext * context, NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - BUILD_TRIPLE_SELECTOR(vA->dataType(), vB->dataType(), vC->dataType(), __matmul, (vA, vB, vC, transA, transB, alpha, beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - - BUILD_TRIPLE_TEMPLATE(template void __matmul, (NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha, double beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - } -} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index 06fe2eec2..6ebca9184 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -76,9 +76,6 @@ namespace helpers { BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); } - - BUILD_SINGLE_TEMPLATE(template void maxPoolingFunctor_, (nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices), FLOAT_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index c38df008f..0fc6eea0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -32,7 +32,6 @@ namespace nd4j { in.applyLambda(lambda, &out); } - BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); void __toggle_bits(nd4j::LaunchContext * context, NDArray& in, NDArray& out) { BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index 71641f215..3536f9f62 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -56,9 +56,6 @@ static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const N BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES); } - -BUILD_SINGLE_TEMPLATE(template void triuBP_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void trace_(const NDArray& input, NDArray& output) { @@ -78,8 +75,6 @@ static void trace_(const NDArray& input, NDArray& output) { BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void trace_, (const NDArray& input, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { @@ -173,14 +168,6 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); - - - - - - - ////////////////////////////////////////////////////////////////////////// template @@ -387,8 +374,6 @@ void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, co BUILD_SINGLE_SELECTOR(input.dataType(), pad_, (mode, input, paddings, output, padValue), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void pad_, (const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue), LIBND4J_TYPES); - //////////////////////////////////////////////////////////////////////// /*// initial values of inIdx, outIdx, dim must be equal to zero template @@ -623,9 +608,8 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { //////////////////////////////////////////////////////////////////////// void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES); } -BUILD_DOUBLE_TEMPLATE(template void gatherND_, (NDArray& input, NDArray& indices, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES); //////////////////////////////////////////////////////////////////////// @@ -705,8 +689,6 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void gather_, (NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// void eye(nd4j::LaunchContext * context, NDArray& output) { @@ -826,7 +808,6 @@ static void mergeMaxIndex_(const std::vector& inArrs, NDArray& output) BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeMaxIndex_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -850,8 +831,6 @@ static void mergeMax_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeMax_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void mergeAvg_(const std::vector& inArrs, NDArray& output) { @@ -874,7 +853,6 @@ static void mergeAvg_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -898,8 +876,6 @@ static void mergeAdd_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { @@ -970,11 +946,6 @@ void clipByNorm(nd4j::LaunchContext * context, NDArray& input, NDArray& output, BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - - - - diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index 33805e335..1397874f8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -99,7 +99,7 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a const auto yType = alpha.dataType(); NDArray::prepareSpecialUse({&output}, {&input, &alpha}); - BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(xType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input, &alpha}); manager.synchronize(); @@ -189,7 +189,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& const auto zType = alpha.dataType(); NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(xType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), FLOAT_TYPES); NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); manager.synchronize(); @@ -574,14 +574,6 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (input, threshold, dLdO, output), FLOAT_TYPES); } - -BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); - - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu index def7d316f..e8062e126 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu @@ -78,7 +78,6 @@ static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int thr adjustHueCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e(0), dimC); } -BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { @@ -94,7 +93,7 @@ void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray PointersManager manager(context, "adjustHue"); NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu index ce910a892..4ab8da304 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu @@ -80,7 +80,6 @@ static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const adjustSaturationCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e(0), dimC); } -BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { @@ -96,7 +95,7 @@ void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const PointersManager manager(context, "adjustSaturation"); NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, factorScalarArr}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index 8a5dbd744..7678779ac 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -182,7 +182,6 @@ __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int th batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); } -BUILD_SINGLE_TEMPLATE(template void batchnormCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const double epsilon), FLOAT_TYPES); /////////////////////////////////////////////////////////////////// template @@ -198,7 +197,6 @@ __host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int t batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); } -BUILD_SINGLE_TEMPLATE(template void batchnormCudaLauncher2, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int numDims, const int* dims, const double epsilon), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu b/libnd4j/include/ops/declarable/helpers/cuda/bds.cu index 6aef74adb..ef501eac0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/bds.cu @@ -107,7 +107,6 @@ namespace helpers { return Status::OK(); return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template void bdsLoopH, (cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu index 2088e18fe..e02bce146 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu @@ -189,7 +189,6 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void col2imCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *im, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { @@ -201,7 +200,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&im}, {&col}); - BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); NDArray::registerSpecialUse({&im}, {&col}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index e993b370e..44a0156d7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -98,7 +98,6 @@ static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlo vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void vol2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *vol, const Nd4jLong *volShapeInfo, void *col, const Nd4jLong *colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { @@ -205,7 +204,6 @@ static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlo col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void col2volCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *vol, const Nd4jLong *volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { @@ -285,7 +283,7 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -345,7 +343,7 @@ static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -390,7 +388,7 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -488,7 +486,6 @@ template static void avgPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void avgPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -582,7 +579,6 @@ template static void pnormPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void pnormPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -679,7 +675,6 @@ template static void maxPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void maxPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { @@ -689,15 +684,15 @@ void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& inp switch (poolingMode) { case MAX_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; case AVG_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; case PNORM_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; default: @@ -845,7 +840,6 @@ static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerB pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling3dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -857,49 +851,12 @@ void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& inp const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1032,7 +989,6 @@ static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPe pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling2dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1047,7 +1003,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); manager.synchronize(); @@ -1201,7 +1157,6 @@ static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPe pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling3dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1216,7 +1171,7 @@ void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& i const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); manager.synchronize(); @@ -1292,11 +1247,10 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N delete gradI; } } -BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -1374,11 +1328,10 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con delete gradI; } } -BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } @@ -1434,7 +1387,6 @@ static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsP upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { @@ -1446,7 +1398,7 @@ void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); @@ -1505,7 +1457,6 @@ static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsP upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling3dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { @@ -1517,7 +1468,7 @@ void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); @@ -1579,7 +1530,6 @@ static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int thread upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling2dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { @@ -1591,7 +1541,7 @@ void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&gradO}); manager.synchronize(); @@ -1656,7 +1606,6 @@ static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int thread upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling3dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { @@ -1668,7 +1617,7 @@ void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&gradO}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index 423944e0f..f4dff2279 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -100,19 +100,12 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha input->syncToDevice(); diagPartFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), outLen, inLen); -// int i(0), j; -// for (j = 0;j < outLen; j++) { -// output->p(j, input->e(i)); -// i += outLen + 1; -// } - } - BUILD_SINGLE_TEMPLATE(template void _diagPartFunctor, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output);, LIBND4J_TYPES); void diagPartFunctor(nd4j::LaunchContext * context, NDArray const* input, NDArray* output) { auto zType = output->dataType(); - BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), NUMERIC_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu index e23c4c84f..a636af891 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu @@ -114,8 +114,6 @@ static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPer dilation2dCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); } -BUILD_DOUBLE_TEMPLATE(template void dilation2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES); - void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { PointersManager manager(context, "dilation2d"); @@ -125,7 +123,7 @@ void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({output}, {input, weights}); - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, weights}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 952bf47c7..a01b4f555 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -73,8 +73,6 @@ namespace helpers { NDArray::registerSpecialUse({output}, {input}); } - BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); - template int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { @@ -124,8 +122,6 @@ namespace helpers { BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); - /////////////////////////////////// backrpopagations /////////////////////////////////////////////// template static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) { @@ -260,17 +256,14 @@ namespace helpers { int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index d6a2d26bb..857ebed38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -306,7 +306,7 @@ namespace nd4j { NDArray::prepareSpecialUse({}, {indices, input}); - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), NUMERIC_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({}, {indices, input}); @@ -336,7 +336,7 @@ namespace nd4j { NDArray::prepareSpecialUse({output}, {}); - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), NUMERIC_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {}); @@ -346,22 +346,15 @@ namespace nd4j { int dynamicStitchFunctorBP(nd4j::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { auto xType = inputs.at(0)->dataType(); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), NUMERIC_TYPES); } void dynamicPartitionFunctorBP(nd4j::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), NUMERIC_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, (NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, (std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList);, LIBND4J_TYPES); - - BUILD_DOUBLE_TEMPLATE(template void _dynamicPartitionFunctor, (nd4j::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList);, LIBND4J_TYPES, INTEGER_TYPES); - BUILD_DOUBLE_TEMPLATE(template int _dynamicStitchFunctor, (nd4j::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output);, LIBND4J_TYPES, INTEGER_TYPES); - - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu index 5415ddab1..aabd9e949 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu @@ -164,13 +164,13 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* sizeof(Nd4jLong))); NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->getSpecialBuffer(), xShapeInfo, xOffsets, pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->getSpecialBuffer(), zShapeInfo, zOffsets), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->getSpecialBuffer(), xShapeInfo, xOffsets, pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->getSpecialBuffer(), zShapeInfo, zOffsets), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {input, pIndices}); manager.synchronize(); } else { NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {input, pIndices}); } @@ -181,12 +181,6 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* } } - -BUILD_DOUBLE_TEMPLATE(template void gatherCudaLauncher, (const cudaStream_t *stream, const int numOfSubArrs, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets), NUMERIC_TYPES, INTEGER_TYPES); -BUILD_DOUBLE_TEMPLATE(template void gatherCudaLinear, (const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo), NUMERIC_TYPES, INTEGER_TYPES); - - - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index 614ac95c1..71dc284a6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -120,7 +120,6 @@ namespace nd4j { gatherNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } - BUILD_DOUBLE_TEMPLATE(template void gatherNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), LIBND4J_TYPES, INTEGER_TYPES); /////////////////////////////////////////////////////////////////// void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { @@ -137,7 +136,7 @@ namespace nd4j { PointersManager manager(context, "gatherND"); NDArray::prepareSpecialUse({&output}, {&input, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input, &indices}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index e04b1b57a..eda19ccd8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -125,7 +125,7 @@ namespace nd4j { double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu index 73cae9d80..3e8ec6836 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu @@ -85,7 +85,6 @@ template static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) { im2colCuda<<>>(image, columns, imShapeInfo, colShapeInfo, sH, sW, pH, pW, dH, dW, zeroPadVal); } -BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext& context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { @@ -96,7 +95,7 @@ void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns const int blocksPerGrid = (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; NDArray::prepareSpecialUse({&columns}, {&image}); - BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), FLOAT_TYPES); NDArray::registerSpecialUse({&columns}, {&image}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index cd6887bf0..2cec0a065 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -85,8 +85,8 @@ namespace helpers { *shouldSelect = shouldSelectShared; } } - template + template static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) { __shared__ I* indexBuf; __shared__ Nd4jLong* srcBuf; @@ -115,15 +115,15 @@ namespace helpers { sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); // TO DO: sort indices using scales as value row //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e(i) > scales->e(j);}); - I* indexBuf = reinterpret_cast(indices->specialBuffer()); + auto indexBuf = reinterpret_cast(indices->specialBuffer()); NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); int numSelected = 0; int numBoxes = boxes->sizeAt(0); - T* boxesBuf = reinterpret_cast(boxes->specialBuffer()); + auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); - I* selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); - I* outputBuf = reinterpret_cast(output->specialBuffer()); + auto selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); bool* shouldSelectD; auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); @@ -138,8 +138,7 @@ namespace helpers { throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err); } - shouldSelectKernel <<< 128, 256, 1024, *stream >>> - (boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); + shouldSelectKernel<<<128, 256, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost); if (err) { throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err); @@ -161,9 +160,8 @@ namespace helpers { } void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void nonMaxSuppressionV2_, (nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index 46f972f44..a0f30a116 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -34,7 +34,6 @@ namespace nd4j { theFirst->applyPairwiseLambda(theSecond, functor, nullptr); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); @@ -48,7 +47,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); @@ -63,8 +61,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -78,8 +74,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -93,8 +87,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -108,8 +100,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 9ad1ee0ad..017180b38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -36,8 +36,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -53,8 +51,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -68,8 +64,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -83,8 +77,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index 6d0788c64..defdfaf09 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -35,8 +35,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void cubeDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -51,8 +49,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reduceNorm1_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -67,8 +63,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropy_, (NDArray* logits, NDArray* labels, NDArray* output);, FLOAT_TYPES); - void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); } @@ -87,8 +81,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropyGrad_, (NDArray* logits, NDArray* labels, NDArray*output);, FLOAT_TYPES); - void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); } @@ -106,8 +98,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softSignDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -122,8 +112,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softPlusDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -141,8 +129,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -156,8 +142,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardSigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -197,12 +181,10 @@ namespace helpers { void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* axis, NDArray*output);, FLOAT_TYPES); void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* subtrah, NDArray* axis, NDArray*output);, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -246,7 +228,7 @@ void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArra NDArray::registerSpecialUse({output}, {targets, input, weights}); } -BUILD_SINGLE_TEMPLATE(template void weightedCrossEntropyWithLogitsFunctor_, (NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output), FLOAT_TYPES); + } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index baabf6574..f27511b3a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -148,7 +148,7 @@ namespace helpers { input.syncToDevice(); gradO.syncToDevice(); - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); gradI.tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 9c8dff3a5..ffd652ee7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -212,8 +212,6 @@ namespace helpers { invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); - void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); } @@ -232,8 +230,6 @@ namespace helpers { invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); - void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); } @@ -562,8 +558,6 @@ namespace helpers { return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); - int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); } @@ -612,8 +606,6 @@ namespace helpers { return ND4J_STATUS_OK; } - BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); - int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu b/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu deleted file mode 100644 index 322966836..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu +++ /dev/null @@ -1,39 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.12.17. -// - -#include - -namespace nd4j { - namespace ops { - namespace helpers { - template - void __matmul(NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - - } - - - void _matmul(nd4j::LaunchContext * context, NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - BUILD_TRIPLE_SELECTOR(vA->dataType(), vB->dataType(), vC->dataType(), __matmul, (vA, vB, vC, transA, transB, alpha, beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - - BUILD_TRIPLE_TEMPLATE(template void __matmul, (NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha, double beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - } -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index d3aa58a9c..d5af6328a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -88,13 +88,10 @@ namespace helpers { void maxPoolingFunctor(nd4j::LaunchContext * context, nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { NDArray::prepareSpecialUse({values, indices}, {input}); auto yType = indices == nullptr ? nd4j::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({values, indices}, {input}); } - - BUILD_DOUBLE_TEMPLATE(template void maxPoolingFunctor_, (nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices), FLOAT_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index 0af1f0eda..a2aec252e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -107,7 +107,6 @@ namespace nd4j { NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } - BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 3c8d159be..ceb748453 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -79,10 +79,9 @@ namespace nd4j { } void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void mergeMaxIndex_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -128,7 +127,6 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeMax_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); void mergeMax(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); @@ -176,10 +174,9 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); void mergeAvg(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -224,10 +221,10 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); void mergeAdd(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu index 2647a53df..ea4a1e146 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu @@ -136,7 +136,7 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// void meshgrid(nd4j::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), NUMERIC_TYPES); for (auto v:outArrs) v->tickWriteDevice(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index 12f888005..75c73f96b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -109,8 +109,6 @@ namespace nd4j { NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } - BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index 80662a19b..aeddd3b97 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -88,8 +88,7 @@ namespace helpers { BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (context, input, n, output, reverse), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (nd4j::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); - + } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu index b268e6366..ef74180c8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu @@ -128,7 +128,6 @@ namespace nd4j { padCuda<<>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); } - BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES); /////////////////////////////////////////////////////////////////// void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { @@ -144,7 +143,7 @@ namespace nd4j { const auto xType = input.dataType(); const auto yType = paddings.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); manager.synchronize(); @@ -272,11 +271,9 @@ namespace nd4j { } void mirrorPad(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void mirrorPad_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES, INTEGER_TYPES); - } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu index 90b9e5d5f..53cfcc22d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu @@ -160,7 +160,7 @@ void prefix(nd4j::LaunchContext * context, scalar::Ops op, const NDArray* x, NDA PointersManager manager(context, "prefix"); NDArray::prepareSpecialUse({z}, {x}); - BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), NUMERIC_TYPES); NDArray::registerSpecialUse({z}, {x}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index 323877e47..7e8ddb2a7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -46,7 +46,7 @@ namespace helpers { BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index b6f0c215a..776d92c45 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -415,7 +415,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } else { @@ -426,7 +426,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } NDArray::registerSpecialUse({&output}, {&updates, &indices}); @@ -714,7 +714,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDLockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), output.getSpecialShapeInfo(), packX.numberOfTads(), packZ.numberOfTads(), shape::length(packY.primaryShapeInfo())), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDLockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), output.getSpecialShapeInfo(), packX.numberOfTads(), packZ.numberOfTads(), shape::length(packY.primaryShapeInfo())), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } else { @@ -725,7 +725,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } NDArray::registerSpecialUse({&output}, {&updates, &indices}); @@ -797,26 +797,18 @@ void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArra if(calcGrad) { NDArray::prepareSpecialUse({&updates}, {&indices}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INTEGER_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); NDArray::registerSpecialUse({&updates}, {&indices}); } else { NDArray::prepareSpecialUse({&output}, {&indices, &updates}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INTEGER_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&indices, &updates}); } manager.synchronize(); } - - - -BUILD_DOUBLE_TEMPLATE(template void scatterCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterLockCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterNDLockCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong *zShapeInfo, const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu index 5d3c4eb52..f1eda6b01 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu @@ -70,7 +70,7 @@ namespace nd4j { NDArray::prepareSpecialUse({&input}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&input}, {&updates, &indices}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu index 67cb77b5c..4aa5c762d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu @@ -40,12 +40,9 @@ namespace helpers { } bool segmentIndicesValidate(nd4j::LaunchContext* context , NDArray* indices, NDArray& expected, NDArray& output) { - BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), NUMERIC_TYPES, INTEGER_TYPES); - - // -------------------------------------------------------------------------------------------------------------- // // Unsorted segment ops functors implementation // -------------------------------------------------------------------------------------------------------------- // @@ -85,9 +82,9 @@ namespace helpers { } bool unsortedSegmentIndicesValidate(nd4j::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INDEXING_TYPES); } - BUILD_SINGLE_TEMPLATE(template bool unsortedSegmentIndicesValidate_, (nd4j::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output), INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- // @@ -126,9 +123,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INDEXING_TYPES); } - BUILD_SINGLE_TEMPLATE(template void fillUpSegments_, (NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens), INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index a1792750f..20796b1d1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -201,9 +201,8 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -241,10 +240,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // // segment max // -------------------------------------------------------------------------------------------------------------- // @@ -371,10 +369,8 @@ namespace nd4j { // -------------------------------------------------------------------------------------------------------------- // int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -418,10 +414,8 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index 19c50728a..c60272188 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -186,9 +186,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // template static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { @@ -226,10 +226,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -351,11 +349,9 @@ namespace helpers { // segmen mean bp main int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -402,12 +398,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index b5c76e18d..de602201b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -192,9 +192,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // @@ -235,11 +234,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); template static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, @@ -366,10 +363,8 @@ namespace helpers { // segmen min int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -412,12 +407,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 0a7c73040..7454756b5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -192,9 +192,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -233,10 +232,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -360,11 +357,9 @@ namespace helpers { int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -407,10 +402,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 6e3ab24d9..875f63e77 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -147,9 +147,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, @@ -270,11 +269,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 4f2cc92a1..1d9d983ef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -190,9 +190,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // template static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { @@ -230,11 +230,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // // Backpropagate ops @@ -344,10 +342,8 @@ namespace helpers { int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -383,10 +379,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index fd2f3db6c..150c616a6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -231,7 +231,6 @@ static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock sruBICuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void sruBICudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index 36b369113..db6213dd3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -101,7 +101,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con const auto yType = targets->dataType(); NDArray::prepareSpecialUse({output}, {predictions, targets}); - BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(), targets->getSpecialBuffer(), targets->getSpecialShapeInfo(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(), targets->getSpecialBuffer(), targets->getSpecialShapeInfo(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {predictions, targets}); manager.synchronize(); @@ -269,7 +269,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con int topKFunctor(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { input->syncToDevice(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INDEXING_TYPES); values->tickWriteDevice(); indices->tickWriteDevice(); @@ -277,9 +277,6 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con return Status::OK(); } - - BUILD_DOUBLE_TEMPLATE(template int topKFunctor_, (nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort), LIBND4J_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 19c726581..bb311ed01 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -68,7 +68,6 @@ __host__ static void invertPermutationCudaLauncher(const int blocksPerGrid, cons invertPermutationCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void invertPermutationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void invertPermutation(nd4j::LaunchContext* context, const NDArray& input, NDArray& output) { @@ -149,7 +148,7 @@ static void traceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock traceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diagLen); } -BUILD_SINGLE_TEMPLATE(template void traceCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint diagLen), LIBND4J_TYPES); + /////////////////////////////////////////////////////////////////// void trace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output) { @@ -214,7 +213,6 @@ static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBloc triuBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diag); } -BUILD_SINGLE_TEMPLATE(template void triuBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag), LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// void triuBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { @@ -280,7 +278,6 @@ static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBloc tileBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, globMem); } -BUILD_SINGLE_TEMPLATE(template void tileBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// @@ -526,7 +523,7 @@ static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsP else // means tads using clipByNormBPTadsCuda<<>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast(clipNormVal)); } -BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); +BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { @@ -547,7 +544,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr if(dimensions.empty() || dimensions.size() == input.rankOf()) { // means whole array const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), nullptr, gradI.getSpecialBuffer(), gradI.getSpecialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), nullptr, gradI.getSpecialBuffer(), gradI.getSpecialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES); } else { // means tads using @@ -556,7 +553,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.getShapeInfo(), dimensions); const int blocksPerGrid = packX.numberOfTads(); - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.getSpecialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.getSpecialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.getSpecialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.getSpecialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES); } NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index 1ce00f44a..731c5a5f9 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -341,7 +341,7 @@ namespace nd4j { if (DataTypeUtils::isR(xType)) { COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } else if (DataTypeUtils::isZ(xType)) { auto zShapeArr = INPUT_VARIABLE(0); auto zShapeVector = zShapeArr->asVectorT(); diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index d20bf3d04..3e35e2c11 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -47,7 +47,7 @@ namespace nd4j { Nd4jLong *newShape; COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp index d870f15d0..de8248d25 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp @@ -61,7 +61,7 @@ namespace nd4j { Nd4jLong *newShape; COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } } } diff --git a/libnd4j/include/play.h b/libnd4j/include/play.h index 1d4ad80dc..ecafe84ea 100644 --- a/libnd4j/include/play.h +++ b/libnd4j/include/play.h @@ -40,7 +40,7 @@ (float, long, long) -BUILD_SINGLE_SELECTOR_THRICE(xType, template class functionName, , DATA_TYPES); +BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) //BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES); diff --git a/libnd4j/include/type_boilerplate.h b/libnd4j/include/type_boilerplate.h index 69ad370b0..bd235726a 100644 --- a/libnd4j/include/type_boilerplate.h +++ b/libnd4j/include/type_boilerplate.h @@ -546,10 +546,13 @@ #ifndef __CLION_IDE__ #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} + #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } @@ -559,8 +562,10 @@ #else #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) @@ -596,6 +601,12 @@ #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) +#define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; +#define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; +#define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) @@ -624,6 +635,7 @@ #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) +#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64 #define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64 #define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16 diff --git a/libnd4j/include/types/types.h b/libnd4j/include/types/types.h index b11f44c6e..9c8dcb273 100644 --- a/libnd4j/include/types/types.h +++ b/libnd4j/include/types/types.h @@ -76,6 +76,10 @@ (nd4j::DataType::FLOAT32, float), \ (nd4j::DataType::DOUBLE, double) +#define INDEXING_TYPES \ + (nd4j::DataType::INT32, int32_t), \ + (nd4j::DataType::INT64, Nd4jLong) + #define FLOAT_NATIVE \ (nd4j::DataType::FLOAT32, float), \ (nd4j::DataType::DOUBLE, double) diff --git a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp index 5d3ce0ea7..0fa4d687d 100644 --- a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp @@ -34,6 +34,8 @@ public: int dimensionLength = 2; }; +#ifndef __CUDABLAS__ + TEST_F(BroadcastMultiDimTest,MultimDimTest) { shape::TAD *tad = new shape::TAD(); tad->init(inputShapeBuffer,dimensions,dimensionLength); @@ -58,4 +60,6 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) { } delete tad; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu index 65509a1d4..19f107ea4 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -452,7 +452,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_20) { ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } - +/* ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_21) { @@ -600,6 +600,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_28) { ASSERT_TRUE(c.equalsTo(&exp)); } + */ ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_1) { @@ -918,6 +919,7 @@ TEST_F(CudaBasicsTests2, mmulMxV_18) { } ////////////////////////////////////////////////////////////////////////// +/* TEST_F(CudaBasicsTests2, mmulMxV_19) { const Nd4jLong M = 3; @@ -1150,4 +1152,5 @@ TEST_F(CudaBasicsTests2, mmulDot_4) { nd4j::MmulHelper::mmul(&x, &y, &z); ASSERT_TRUE(z.equalsTo(&exp)); -} \ No newline at end of file +} + */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 5bf5f8013..501a29b8c 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -27,6 +27,7 @@ #include #include #include +#include using namespace nd4j; @@ -55,9 +56,9 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) { } TEST_F(DataTypesValidationTests, Basic_Test_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); weights.assign(2.0); input.linspace(1); @@ -75,10 +76,10 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) { TEST_F(DataTypesValidationTests, Basic_Test_3) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); weights.assign(2.0); input.linspace(1); @@ -104,6 +105,14 @@ TEST_F(DataTypesValidationTests, Basic_Test_4) { ASSERT_EQ(ND4J_STATUS_VALIDATION, result); } +TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); + + ASSERT_TRUE(x.sumNumber().e(0) > 0); +} + TEST_F(DataTypesValidationTests, cast_1) { float16 x = static_cast(1.f); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 8835ecc8e..5375f8fca 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); -// result->printIndexedBuffer("OOOOUUUUTTT"); + //result->printIndexedBuffer("OOOOUUUUTTT"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); -// result->printBuffer("NonMaxSuppression OUtput2"); + result->printBuffer("NonMaxSuppression OUtput2"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 7f66e9be3..6fe3dfac6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -729,8 +729,8 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, tensormmul_6) { - NDArray x('c', {1}, {2}); - NDArray y('c', {2,1,2}, {1,2,3,4}); + NDArray x('c', {1}, {2}, nd4j::DataType::FLOAT32); + NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); nd4j::ops::tensormmul op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 53c1e9a99..06d677b27 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -24,6 +24,7 @@ #include #include #include +#include using namespace nd4j; @@ -363,6 +364,77 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { delete result; } +TEST_F(DeclarableOpsTests15, test_concat_column_1) { + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); + + nd4j::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); + + z.printIndexedBuffer("z"); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_concat_large_1) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 300}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1}); + + ASSERT_NEAR((float) e, row->e(0), 1e-5f); + + delete row; + } +} + +TEST_F(DeclarableOpsTests15, test_concat_large_2) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 5, 20}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1, 2}); + + ASSERT_NEAR((float) e, row->meanNumber().e(0), 1e-5f); + + delete row; + } +} + TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 9f6f3d787..b646bacab 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1212,6 +1212,18 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) { ASSERT_EQ(e, z); } +TEST_F(JavaInteropTests, test_bfloat16_rng) { + if (!Environment::getInstance()->isCPU()) + return; + + auto z = NDArrayFactory::create('c', {10}); + RandomGenerator rng(119, 323841120L); + bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; + execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args); + z.printIndexedBuffer("z"); + ASSERT_TRUE(z.sumNumber().e(0) > 0); +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu new file mode 100644 index 000000000..d7632ace5 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu @@ -0,0 +1,125 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace nd4j; +using namespace nd4j::ops; + +class LaunchContextCudaTests : public testing::Test { + // +}; + + +void acquireContext(int threadId, int &deviceId) { + deviceId = AffinityManager::currentDeviceId(); + + nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId); + + auto lc = LaunchContext::defaultContext(); + nd4j_printf("LC: [%p]\n", lc); + + nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream()); +} + +TEST_F(LaunchContextCudaTests, basic_test_1) { + int deviceA, deviceB; + std::thread threadA(acquireContext, 0, std::ref(deviceA)); + std::thread threadB(acquireContext, 1, std::ref(deviceB)); + + threadA.join(); + threadB.join(); + nd4j_printf("All threads joined\n",""); + + if (AffinityManager::numberOfDevices() > 1) + ASSERT_NE(deviceA, deviceB); +} + +void fillArray(int tid, std::vector &arrays) { + auto array = NDArrayFactory::create_('c', {3, 10}); + nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId()); + array->assign(tid); + arrays[tid] = array; +} + +TEST_F(LaunchContextCudaTests, basic_test_2) { + std::vector arrays(2); + + std::thread threadA(fillArray, 0, std::ref(arrays)); + std::thread threadB(fillArray, 1, std::ref(arrays)); + + threadA.join(); + threadB.join(); + + for (int e = 0; e < 2; e++) { + auto array = arrays[e]; + ASSERT_EQ(e, array->e(0)); + + delete array; + } +} + +void initAffinity(int tid, std::vector &aff) { + auto affinity = AffinityManager::currentDeviceId(); + aff[tid] = affinity; + nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); +} + +TEST_F(LaunchContextCudaTests, basic_test_3) { + auto totalThreads = AffinityManager::numberOfDevices() * 4; + nd4j_printf("Total threads: %i\n", totalThreads); + std::vector affinities(totalThreads); + + for (int e = 0; e < totalThreads; e++) { + std::thread thread(initAffinity, e, std::ref(affinities)); + + thread.join(); + } + + std::vector hits(AffinityManager::numberOfDevices()); + std::fill(hits.begin(), hits.end(), 0); + + // we need to make sure all threads were attached to "valid" devices + for (int e = 0; e < totalThreads; e++) { + auto aff = affinities[e]; + ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); + + hits[aff]++; + } + + // now we check if all devices got some threads + for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { + ASSERT_GT(hits[e], 0); + } +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index 9151b70bd..aeac06ccb 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -465,6 +465,7 @@ TEST_F(LegacyOpsTests, PowDerivative_1) { ASSERT_TRUE(exp.equalsTo(&x)); } +#ifndef __CUDABLAS__ TEST_F(LegacyOpsTests, reduce3_1) { Nd4jLong yShape[2] = {4,4}; @@ -494,6 +495,8 @@ TEST_F(LegacyOpsTests, reduce3_1) { delete[] xShapeBuffer; } +#endif + TEST_F(LegacyOpsTests, Reduce3_2) { auto x = NDArrayFactory::create('c', {5, 5}); diff --git a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp index 2545bf919..e4c28e9ba 100644 --- a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp @@ -32,6 +32,7 @@ public: int dimensionLength = 1; }; +#ifndef __CUDABLAS__ TEST_F(EqualsTest,Eps) { auto val = nd4j::NDArrayFactory::create(0.0f); @@ -45,3 +46,5 @@ TEST_F(EqualsTest,Eps) { val.shapeInfo()); ASSERT_TRUE(val.e(0) < 0.5); } + +#endif diff --git a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp index c6bb7af5b..608ee443f 100644 --- a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp @@ -31,13 +31,17 @@ class QuantizationTests : public testing::Test { }; TEST_F(QuantizationTests, Basic_Test_1) { +#ifndef __CUDABLAS__ auto s = TypeCast::estimateQuantizedSize(10); ASSERT_EQ(18, s); +#endif } TEST_F(QuantizationTests, Basic_Test_2) { +#ifndef __CUDABLAS__ auto s = TypeCast::estimateQuantizedSize(1); ASSERT_EQ(9, s); +#endif } TEST_F(QuantizationTests, Compression_Test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp index 0877632d5..4df0f3dc8 100644 --- a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp @@ -62,7 +62,7 @@ public: std::vector dim = {1, 2, 3}; }; - +#ifndef __CUDABLAS__ TEST_F(EuclideanDistanceTest,Test1) { //int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength); nd4j::ArrayOptions::setDataType(shapeBuffer, nd4j::DataType::FLOAT32); @@ -152,4 +152,6 @@ TEST_F(ReduceTest,MatrixTest) { delete tad; delete[] xShapeInfo; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp index 7ea44aa17..1f352dd2e 100644 --- a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp @@ -32,6 +32,7 @@ public: }; TEST_F(TypeCastTests, Test_Cast_1) { +#ifndef __CUDABLAS__ const int limit = 100; auto src = new double[limit]; auto z = new float[limit]; @@ -51,6 +52,7 @@ TEST_F(TypeCastTests, Test_Cast_1) { delete[] src; delete[] z; delete[] exp; +#endif } TEST_F(TypeCastTests, Test_ConvertDtype_1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java new file mode 100644 index 000000000..671ef613d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InferenceAdapter.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.adapters; + +/** + * This interface describes methods needed to convert custom JVM objects to INDArrays, suitable for feeding neural networks + * + * @param type of the Input for the model. I.e. String for raw text + * @param type of the Output for the model, I.e. Sentiment, for Text->Sentiment extraction + * + * @author raver119@gmail.com + */ +public interface InferenceAdapter extends InputAdapter, OutputAdapter { +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java new file mode 100644 index 000000000..09ca4eaa6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/InputAdapter.java @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.adapters; + +import org.nd4j.linalg.dataset.MultiDataSet; + +/** + * This interface describes method for transformation from object of type I to MultiDataSet. + * + */ +public interface InputAdapter { + /** + * This method converts input object to MultiDataSet + * @param input + * @return + */ + MultiDataSet apply(I input); +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java similarity index 88% rename from deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java index 4b32e83c2..ba1ff40d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/OutputAdapter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/adapters/OutputAdapter.java @@ -14,15 +14,15 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.nn.api; +package org.nd4j.adapters; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; /** - * This interface describes entity used to conver neural network output to specified class. - * I.e. INDArray -> int[] on the fly. + * This interface describes entity used to convert neural network output to specified class. + * I.e. INDArray -> int[] or INDArray -> Sentiment on the fly * * PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference. * This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java index 15ae181f2..5625db5a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java @@ -34,20 +34,6 @@ public interface AffinityManager { */ Integer getDeviceForCurrentThread(); - /** - * This method returns deviceId for specified thread - * @param thread - * @return - */ - Integer getDeviceForThread(Thread thread); - - /** - * This method returns deviceId for specified threadId - * - * @param threadId - * @return - */ - Integer getDeviceForThread(long threadId); /** * This method returns id of current device for a given INDArray @@ -57,23 +43,6 @@ public interface AffinityManager { */ Integer getDeviceForArray(INDArray array); - /** - * This method attaches specified thread to specified device - * - * @param thread - * @param deviceId - */ - void attachThreadToDevice(Thread thread, Integer deviceId); - - - /** - * This method attaches specified thread (by Id) to specified device - * - * @param threadId java ID of the thread - * @param deviceId - */ - void attachThreadToDevice(long threadId, Integer deviceId); - /** * This method returns number of available devices * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java index 40947e7fc..ad0320825 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java @@ -28,26 +28,6 @@ public abstract class BasicAffinityManager implements AffinityManager { return 0; } - @Override - public Integer getDeviceForThread(Thread thread) { - return 0; - } - - @Override - public Integer getDeviceForThread(long threadId) { - return 0; - } - - @Override - public void attachThreadToDevice(Thread thread, Integer deviceId) { - // no-op - } - - @Override - public void attachThreadToDevice(long threadId, Integer deviceId) { - // no-op - } - @Override public Integer getDeviceForArray(INDArray array) { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 3b772dba9..b1c4b34ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -25,6 +25,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp; @@ -43,6 +44,12 @@ public class Concat extends DynamicCustomOp { } + public Concat(int concatDimension, INDArray... arrays) { + super(null, arrays, new INDArray[0]); + this.concatDimension = concatDimension; + addIArgument(concatDimension); + } + public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){ super(null, sameDiff, inputs); addIArgument(concatDimension); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java index e8640bea6..500a1e123 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java @@ -129,13 +129,8 @@ public class AsyncDataSetIterator implements DataSetIterator { if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); - this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, null); + this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, null, deviceId); - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); thread.setDaemon(true); thread.start(); } @@ -229,12 +224,7 @@ public class AsyncDataSetIterator implements DataSetIterator { backedIterator.reset(); shouldWork.set(true); - this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, null); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, null, deviceId); thread.setDaemon(true); thread.start(); @@ -391,13 +381,15 @@ public class AsyncDataSetIterator implements DataSetIterator { .policySpill(SpillPolicy.REALLOCATE).build(); private MemoryWorkspace workspace; + private final int deviceId; protected AsyncPrefetchThread(@NonNull BlockingQueue queue, @NonNull DataSetIterator iterator, - @NonNull DataSet terminator, MemoryWorkspace workspace) { + @NonNull DataSet terminator, MemoryWorkspace workspace, int deviceId) { this.queue = queue; this.iterator = iterator; this.terminator = terminator; + this.deviceId = deviceId; this.setDaemon(true); this.setName("ADSI prefetch thread"); @@ -405,6 +397,7 @@ public class AsyncDataSetIterator implements DataSetIterator { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); externalCall(); try { if (useWorkspace) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java index 078db549a..ac20ce66d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java @@ -116,12 +116,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); - this.thread = new AsyncPrefetchThread(buffer, iterator, terminator); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId); thread.setDaemon(true); thread.start(); @@ -207,12 +202,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { backedIterator.reset(); shouldWork.set(true); - this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, deviceId); thread.setDaemon(true); thread.start(); @@ -340,13 +330,15 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { private MemoryWorkspace workspace; + private final int deviceId; + protected AsyncPrefetchThread(@NonNull BlockingQueue queue, - @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) { + @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) { this.queue = queue; this.iterator = iterator; this.terminator = terminator; - + this.deviceId = deviceId; this.setDaemon(true); this.setName("AMDSI prefetch thread"); @@ -354,6 +346,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); externalCall(); try { if (useWorkspaces) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index b82ce008a..429782c3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -58,6 +58,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.Stack; +import org.nd4j.linalg.api.ops.impl.shape.Tile; import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.impl.*; @@ -2222,7 +2223,6 @@ public class Nd4j { * Write an ndarray to a writer * @param writer the writer to write to * @param write the ndarray to write - * @throws IOException */ public static void write(OutputStream writer, INDArray write) throws IOException { DataOutputStream stream = new DataOutputStream(writer); @@ -2230,14 +2230,12 @@ public class Nd4j { stream.close(); } - /** * Convert an ndarray to a byte array * @param arr the array to convert * @return the converted byte array - * @throws IOException */ - public static byte[] toByteArray(INDArray arr) throws IOException { + public static byte[] toByteArray(@NonNull INDArray arr) throws IOException { if (arr.length() * arr.data().getElementSize() > Integer.MAX_VALUE) throw new ND4JIllegalStateException(""); @@ -2252,15 +2250,13 @@ public class Nd4j { * Read an ndarray from a byte array * @param arr the array to read from * @return the deserialized ndarray - * @throws IOException */ - public static INDArray fromByteArray(byte[] arr) throws IOException { + public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException { ByteArrayInputStream bis = new ByteArrayInputStream(arr); INDArray ret = read(bis); return ret; } - /** * Read line via input streams * @@ -2280,7 +2276,6 @@ public class Nd4j { * @param split the split separator * @param charset the charset * @return the deserialized array. - * @throws IOException */ public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset)); @@ -2449,7 +2444,7 @@ public class Nd4j { } //parse data if (lineNum > 5) { - String[] entries = line.replace("\\],", "").replaceAll("\\]", "").replaceAll("\\[", "").split(sep); + String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep); if (rank == 0) { try { newArr.addi((format.parse(entries[0])).doubleValue()); @@ -2496,7 +2491,6 @@ public class Nd4j { * @return NDArray */ public static INDArray readTxt(String filePath) { - String sep = ","; File file = new File(filePath); InputStream is = null; try { @@ -2505,13 +2499,7 @@ public class Nd4j { } catch (FileNotFoundException e) { throw new RuntimeException(e); } finally { - if (is != null) { - try { - is.close(); - } catch (IOException e) { - e.printStackTrace(); - } - } + IOUtils.closeQuietly(is); } } @@ -2552,9 +2540,9 @@ public class Nd4j { */ public static INDArray createArrayFromShapeBuffer(DataBuffer data, Pair shapeInfo) { int rank = Shape.rank(shapeInfo.getFirst()); - long offset = Shape.offset(shapeInfo.getFirst()); + // removed offset parameter that called a deprecated method which always returns 0. INDArray result = Nd4j.create(data, toIntArray(rank, Shape.shapeOf(shapeInfo.getFirst())), - toIntArray(rank, Shape.stride(shapeInfo.getFirst())), offset, Shape.order(shapeInfo.getFirst())); + toIntArray(rank, Shape.stride(shapeInfo.getFirst())), 0, Shape.order(shapeInfo.getFirst())); if (data instanceof CompressedDataBuffer) result.markAsCompressed(true); @@ -2566,9 +2554,8 @@ public class Nd4j { * * @param dis the data input stream to read from * @return the ndarray - * @throws IOException */ - public static INDArray read(DataInputStream dis) throws IOException { + public static INDArray read(DataInputStream dis) { val headerShape = BaseDataBuffer.readHeader(dis); var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); @@ -2597,7 +2584,6 @@ public class Nd4j { * * @param arr the array to write * @param dataOutputStream the data output stream to write to - * @throws IOException */ public static void write(INDArray arr, DataOutputStream dataOutputStream) throws IOException { //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here @@ -2614,7 +2600,6 @@ public class Nd4j { * Save an ndarray to the given file * @param arr the array to save * @param saveTo the file to save to - * @throws IOException */ public static void saveBinary(INDArray arr, File saveTo) throws IOException { BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(saveTo)); @@ -2629,7 +2614,6 @@ public class Nd4j { * Read a binary ndarray from the given file * @param read the nd array to read * @return the loaded ndarray - * @throws IOException */ public static INDArray readBinary(File read) throws IOException { BufferedInputStream bis = new BufferedInputStream(new FileInputStream(read)); @@ -2645,21 +2629,9 @@ public class Nd4j { * @param arr the array to clear */ public static void clearNans(INDArray arr) { - //BooleanIndexing.applyWhere(arr, Conditions.isNan(), new Value(Nd4j.EPS_THRESHOLD)); getExecutioner().exec(new ReplaceNans(arr, Nd4j.EPS_THRESHOLD)); } - - /** - * Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc - * - * @param reverse the matrix to reverse - * @return the reversed matrix - */ - public static INDArray rot(INDArray reverse) { - INDArray ret = INSTANCE.rot(reverse); - return ret; - } - + /** * Reverses the passed in matrix such that m[0] becomes m[m.length - 1] etc * @@ -2667,8 +2639,6 @@ public class Nd4j { * @return the reversed matrix */ public static INDArray reverse(INDArray reverse) { - //INDArray ret = INSTANCE.reverse(reverse); - //logCreationIfNecessary(ret); return Nd4j.getExecutioner().exec(new OldReverse(reverse)); } @@ -2682,8 +2652,7 @@ public class Nd4j { * @return the 1D range vector */ public static INDArray arange(double begin, double end, double step) { - INDArray ret = INSTANCE.arange(begin, end, step); - return ret; + return INSTANCE.arange(begin, end, step); } /** @@ -2693,8 +2662,7 @@ public class Nd4j { * See {@link #arange(double, double, double)} with step size 1. */ public static INDArray arange(double begin, double end) { - INDArray ret = INSTANCE.arange(begin, end, 1); - return ret; + return INSTANCE.arange(begin, end, 1); } /** @@ -2717,27 +2685,6 @@ public class Nd4j { INSTANCE.copy(a, b); } - /** - * Creates a new matrix where the values of the given vector are the diagonal values of - * the matrix if a vector is passed in, if a matrix is returns the kth diagonal - * in the matrix - * - * @param x the diagonal values - * @param k the kth diagonal to get - * @return new matrix - */ - public static INDArray diag(INDArray x, int k) { - INDArray ret; - if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { - ret = Nd4j.create(new long[]{x.length(), x.length()}); - Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); - } else { - ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))}); - Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); - } - return ret; - } - /** * Creates a new matrix where the values of the given vector are the diagonal values of * the matrix if a vector is passed in, if a matrix is returns the kth diagonal @@ -2747,7 +2694,15 @@ public class Nd4j { * @return new matrix */ public static INDArray diag(INDArray x) { - return diag(x, 0); + INDArray ret; + if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { + ret = Nd4j.create(x.dataType(), x.length(), x.length()); + Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); + } else { + ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1))); + Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); + } + return ret; } /** @@ -2813,23 +2768,9 @@ public class Nd4j { } public static INDArray appendBias(@NonNull INDArray... vectors) { - INDArray ret = INSTANCE.appendBias(vectors); - return ret; + return INSTANCE.appendBias(vectors); } - /** - * Perform an operation along a diagonal - * - * @param x the ndarray to perform the operation on - * @param func the operation to perform - */ - public static void doAlongDiagonal(INDArray x, Function func) { - if (x.isMatrix()) - for (int i = 0; i < x.rows(); i++) - x.put(i, i, func.apply(x.getDouble(i, i))); - } - - ////////////////////// RANDOM /////////////////////////////// /** @@ -2905,7 +2846,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { - INDArray ret = Nd4j.createUninitialized(dataType, shape, order); //INSTANCE.rand(order, shape); + INDArray ret = Nd4j.createUninitialized(dataType, shape, order); return rand(ret); } @@ -2920,7 +2861,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { - INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); //INSTANCE.rand(order, shape); + INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); return rand(ret); } @@ -2935,7 +2876,7 @@ public class Nd4j { if (rows < 1 || columns < 1) throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray"); - INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());//INSTANCE.rand(rows, columns, Nd4j.getRandom()); + INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order()); return rand(ret); } @@ -2974,7 +2915,6 @@ public class Nd4j { */ @Deprecated public static INDArray rand(int[] shape, long seed) { - INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed); return rand(seed, ArrayUtil.toLongArray(shape)); } @@ -3038,8 +2978,6 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(@NonNull Distribution dist, @NonNull long... shape) { - //INDArray ret = INSTANCE.rand(shape, dist); - //logCreationIfNecessary(ret); return dist.sample(shape); } @@ -3052,7 +2990,7 @@ public class Nd4j { * @return the random ndarray with the specified shape */ public static INDArray rand(int rows, int columns, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(new int[] {rows, columns}, order());//INSTANCE.rand(rows, columns, rng); + INDArray ret = createUninitialized(new int[] {rows, columns}, order()); return rand(ret, rng); } @@ -3069,7 +3007,7 @@ public class Nd4j { */ @Deprecated public static INDArray rand(long[] shape, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); + INDArray ret = createUninitialized(shape, order()); return rand(ret, min, max, rng); } @@ -3083,7 +3021,7 @@ public class Nd4j { * @return a random matrix of the specified shape and range */ public static INDArray rand(double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { - INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, min, max, rng); + INDArray ret = createUninitialized(shape, order()); return rand(ret, min, max, rng); } @@ -3098,7 +3036,7 @@ public class Nd4j { * @return a drandom matrix of the specified shape and range */ public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) { - INDArray ret = createUninitialized(rows, columns);//INSTANCE.rand(rows, columns, min, max, rng); + INDArray ret = createUninitialized(rows, columns); return rand(ret, min, max, rng); } @@ -3127,7 +3065,7 @@ public class Nd4j { * Create a ndarray of the given shape and data type with values from N(0,1) * * @param shape the shape of the ndarray - * @return + * @return new array with random values */ public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) { return randn(dataType, ArrayUtil.toLongArray(shape)); @@ -3207,7 +3145,7 @@ public class Nd4j { * Random normal N(0, 1) using the specified seed * * @param shape the shape of the array - * @return + * @return new array with random values */ public static INDArray randn(long seed, @NonNull long... shape) { INDArray ret = Nd4j.createUninitialized(shape, order()); @@ -3219,7 +3157,7 @@ public class Nd4j { * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3243,7 +3181,7 @@ public class Nd4j { * * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns, long seed) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3256,7 +3194,7 @@ public class Nd4j { * @param rows the number of rows in the matrix * @param columns the number of columns in the matrix * @param r the random generator to use - * @return + * @return new array with random values */ public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) { INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order()); @@ -3284,7 +3222,7 @@ public class Nd4j { * * @param shape the shape of the array * @param r the random generator to use - * @return + * @return new array with random values */ public static INDArray randn(@NonNull org.nd4j.linalg.api.rng.Random r, @NonNull long... shape) { final INDArray ret = Nd4j.createUninitialized(shape, order()); @@ -3455,9 +3393,9 @@ public class Nd4j { * * PLEASE NOTE: memory of underlying array will be NOT initialized, and won't be set to 0.0 * - * @param rows - * @param columns - * @return + * @param rows rows + * @param columns columns + * @return uninitialized 2D array of rows x columns */ public static INDArray createUninitialized(long rows, long columns) { return createUninitialized(new long[] {rows, columns}); @@ -3578,7 +3516,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } /** @@ -3587,7 +3525,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(float[][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } /** @@ -3605,7 +3543,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[][][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length); } /** @@ -3614,7 +3552,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(float[][][][] data) { - return create(ArrayUtil.flatten(data), new int[] {data.length, data[0].length, data[0][0].length, data[0][0][0].length}); + return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length, data[0][0][0].length); } /** @@ -3655,8 +3593,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(float[] data, char order) { - INDArray ret = INSTANCE.create(data, order); - return ret; + return INSTANCE.create(data, order); } /** @@ -3667,8 +3604,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(double[] data, char order) { - INDArray ret = INSTANCE.create(data, order); - return ret; + return INSTANCE.create(data, order); } /** @@ -3679,8 +3615,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(int columns, char order) { - INDArray ret = INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); - return ret; + return INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); } /** @@ -3765,8 +3700,7 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(int[] data, long[] shape, long[]strides, char order, DataType type) { - val ret = INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return INSTANCE.create(data, shape, strides, order, type, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -5387,25 +5321,7 @@ public class Nd4j { * @return the tiled ndarray */ public static INDArray tile(INDArray tile, @NonNull int... repeat) { - int d = repeat.length; - long[] shape = ArrayUtil.copy(tile.shape()); - long n = Math.max(tile.length(), 1); - if (d < tile.rank()) { - repeat = Ints.concat(ArrayUtil.nTimes(tile.rank() - d, 1), repeat); - } - for (int i = 0; i < shape.length; i++) { - if (repeat[i] != 1) { - tile = tile.reshape(-1, n).repeat(0, repeat[i]); - } - - long in = shape[i]; - long nOut = in * repeat[i]; - shape[i] = nOut; - n /= Math.max(in, 1); - - } - - return tile.reshape(shape); + return Nd4j.exec(new Tile(new INDArray[]{tile}, new INDArray[]{}, repeat))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index a38a4a198..26d850366 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -57,13 +57,12 @@ public class DeallocatorService { log.debug("Starting deallocator thread {}", e + 1); queues[e] = new ReferenceQueue<>(); + int deviceId = e % numDevices; // attaching queue to its own thread - deallocatorThreads[e] = new DeallocatorServiceThread(queues[e], e); + deallocatorThreads[e] = new DeallocatorServiceThread(queues[e], e, deviceId); deallocatorThreads[e].setName("DeallocatorServiceThread_" + e); deallocatorThreads[e].setDaemon(true); - int deviceId = e % numDevices; - Nd4j.getAffinityManager().attachThreadToDevice(deallocatorThreads[e], deviceId); deviceMap.get(deviceId).add(queues[e]); deallocatorThreads[e].start(); @@ -87,16 +86,19 @@ public class DeallocatorService { private final ReferenceQueue queue; private final int threadIdx; public static final String DeallocatorThreadNamePrefix = "DeallocatorServiceThread thread "; + private final int deviceId; - private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx) { + private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx, int deviceId) { this.queue = queue; this.threadIdx = threadIdx; this.setName(DeallocatorThreadNamePrefix + threadIdx); + this.deviceId = deviceId; setContextClassLoader(null); } @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); boolean canRun = true; long cnt = 0; while (canRun) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index b3885962a..174be9a7d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1175,4 +1175,14 @@ public interface NativeOps { String runFullBenchmarkSuit(boolean printOut); long getCachedMemory(int deviceId); + + OpaqueLaunchContext defaultLaunchContext(); + + Pointer lcScalarPointer(OpaqueLaunchContext lc); + Pointer lcReductionPointer(OpaqueLaunchContext lc); + Pointer lcAllocationPointer(OpaqueLaunchContext lc); + Pointer lcExecutionStream(OpaqueLaunchContext lc); + Pointer lcCopyStream(OpaqueLaunchContext lc); + Pointer lcBlasHandle(OpaqueLaunchContext lc); + Pointer lcSolverHandle(OpaqueLaunchContext lc); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java new file mode 100644 index 000000000..d5f3df5e8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueLaunchContext extends Pointer { + public OpaqueLaunchContext(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java index 6ebdeda6f..df45f85e5 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java @@ -17,8 +17,6 @@ package org.nd4j.jita.allocator; import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; @@ -50,7 +48,7 @@ public interface Allocator { * * @return */ - ExternalContext getDeviceContext(); + CudaContext getDeviceContext(); /** * This methods specifies Mover implementation to be used internally @@ -170,8 +168,6 @@ public interface Allocator { FlowController getFlowController(); - ContextPool getContextPool(); - DataBuffer getConstantBuffer(int[] array); DataBuffer getConstantBuffer(float[] array); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java deleted file mode 100644 index aee69924e..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context; - -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; -import org.apache.commons.lang3.RandomUtils; -import org.nd4j.linalg.jcublas.context.CudaContext; - -import java.util.HashMap; -import java.util.Map; - -/** - * @author raver119@gmail.com - */ -public class ContextPack { - @Getter - @Setter - private Integer deviceId; - @Getter - private int availableLanes; - private Map lanes = new HashMap<>(); - - public ContextPack(int totalLanes) { - availableLanes = totalLanes; - } - - public ContextPack(CudaContext context) { - this.availableLanes = 1; - lanes.put(0, context); - } - - public void addLane(@NonNull Integer laneId, @NonNull CudaContext context) { - lanes.put(laneId, context); - context.setLaneId(laneId); - } - - public CudaContext getContextForLane(Integer laneId) { - return lanes.get(laneId); - } - - public int nextRandomLane() { - if (availableLanes == 1) - return 0; - return RandomUtils.nextInt(0, availableLanes); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java deleted file mode 100644 index 5e2eb20e6..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.RandomUtils; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.cuda.CUcontext; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; - - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Semaphore; - -import org.bytedeco.cuda.cublas.*; -import org.bytedeco.cuda.cusolver.*; -import static org.bytedeco.cuda.global.cublas.*; -import static org.bytedeco.cuda.global.cusolver.*; - -/** - * This is context pool implementation, addressing shared cublas allocations together with shared stream pools - * - * Each context given contains: - * 1. Stream for custom kernel invocations. - * 2. cuBLAS handle tied with separate stream. - * - * @author raver119@gmail.com - */ -@Slf4j -public class BasicContextPool implements ContextPool { - // TODO: number of max threads should be device-dependant - protected static final int MAX_STREAMS_PER_DEVICE = Integer.MAX_VALUE - 1; - - protected volatile Map cuPool = new ConcurrentHashMap<>(); - - protected volatile Map cublasPool = new ConcurrentHashMap<>(); - protected volatile Map solverPool = new ConcurrentHashMap<>(); - - protected volatile Map contextsPool = new ConcurrentHashMap<>(); - - protected volatile Map> contextsForDevices = new ConcurrentHashMap<>(); - - protected Semaphore lock = new Semaphore(1); - - protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - public BasicContextPool() { - - } - - public boolean containsContextForThread(long threadId) { - return contextsPool.containsKey(threadId); - } - - public CudaContext getContextForDevice(Integer deviceId) { - return acquireContextForDevice(deviceId); - } - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - /* - We should check, if we have context for this specific thread/device - If we don't have context for this thread - we should stick to one of existent contexts available at pool - */ - Long threadId = Thread.currentThread().getId(); - if (!contextsPool.containsKey(threadId)) { - // we don't have attached context for this thread. we should pick up existing context for target device (if any). - - try { - // this is lockable thing, but since it locks once per thread initialization, performance impact won't be big - lock.acquire(); - - if (!contextsForDevices.containsKey(deviceId)) { - contextsForDevices.put(deviceId, new ConcurrentHashMap()); - } - - // if we hadn't hit MAX_STREAMS_PER_DEVICE limit - we add new stream. Otherwise we use random one. - if (contextsForDevices.get(deviceId).size() < MAX_STREAMS_PER_DEVICE) { - log.debug("Creating new context..."); - CudaContext context = createNewStream(deviceId); - - getDeviceBuffers(context, deviceId); - - if (contextsForDevices.get(deviceId).size() == 0) { - // if we have no contexts created - it's just awesome time to attach cuBLAS handle here - log.debug("Creating new cuBLAS handle for device [{}]...", deviceId); - - //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); - - cublasHandle_t handle = createNewCublasHandle(context.getOldStream()); - context.setHandle(handle); - //context.setCublasStream(cublasStream); - - cublasPool.put(deviceId, handle); - - log.debug("Creating new cuSolver handle for device [{}]...", deviceId); - - cudaStream_t solverStream = createNewStream(deviceId).getOldStream(); - - cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream); - context.setSolverHandle(solverhandle); - context.setSolverStream(solverStream); - - solverPool.put(deviceId, solverhandle); - - } else { - // just pick handle out there - log.debug("Reusing blas here..."); - cublasHandle_t handle = cublasPool.get(deviceId); - context.setHandle(handle); - - log.debug("Reusing solver here..."); - cusolverDnHandle_t solverHandle = solverPool.get(deviceId); - context.setSolverHandle(solverHandle); - - // TODO: actually we don't need this anymore - // cudaStream_t cublasStream = new cudaStream_t(); - // JCublas2.cublasGetStream(handle, cublasStream); - // context.setCublasStream(cublasStream); - } - - // we need this sync to finish memset - context.syncOldStream(); - - contextsPool.put(threadId, context); - contextsForDevices.get(deviceId).put(contextsForDevices.get(deviceId).size(), context); - - return context; - } else { - Integer rand = RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE); - log.debug("Reusing context: " + rand); - - nativeOps.setDevice(deviceId); - - CudaContext context = contextsForDevices.get(deviceId).get(rand); - - contextsPool.put(threadId, context); - return context; - } - - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - lock.release(); - } - } - - return contextsPool.get(threadId); - } - - @Override - public void releaseContext(CudaContext context) { - // no-op - } - - protected CudaContext createNewStream(Integer deviceId) { - log.trace("Creating new stream for thread: [{}], device: [{}]...", Thread.currentThread().getId(), deviceId); - nativeOps.setDevice(deviceId); - - CudaContext context = new CudaContext(); - context.initOldStream(); - - return context; - } - - protected cublasHandle_t createNewCublasHandle() { - cublasContext pointer = new cublasContext(); - int result = cublasCreate_v2(pointer); - if (result != 0) { - throw new IllegalStateException("Can't create new cuBLAS handle! cuBLAS errorCode: [" + result + "]"); - } - - cublasHandle_t handle = new cublasHandle_t(pointer); - - return handle; - } - - - protected cublasHandle_t createNewCublasHandle(cudaStream_t stream) { - return createNewCublasHandle(); - } - - protected cusolverDnHandle_t createNewSolverHandle() { - cusolverDnContext pointer = new cusolverDnContext(); - int result = cusolverDnCreate(pointer); - if (result != 0) { - throw new IllegalStateException("Can't create new cuBLAS handle! cusolverDn errorCode: [" + result - + "] from cusolverDnCreate()"); - } - - cusolverDnHandle_t handle = new cusolverDnHandle_t(pointer); - - return handle; - } - - protected cusolverDnHandle_t createNewSolverHandle(cudaStream_t stream) { - return createNewSolverHandle(); - } - - protected CUcontext createNewContext(Integer deviceId) { - /* - log.debug("Creating new CUcontext..."); - CUdevice device = new CUdevice(); - CUcontext context = new CUcontext(); - - //JCuda.cudaSetDevice(deviceId); - - - int result = cuDeviceGet(device, deviceId); - if (result != CUresult.CUDA_SUCCESS) { - throw new RuntimeException("Failed to setDevice on driver"); - } - - result = cuCtxCreate(context, 0, device); - if (result != CUresult.CUDA_SUCCESS) { - throw new RuntimeException("Failed to create context on driver"); - } - - return context; - */ - return null; - } - - /** - * This methods reset everything in pool, forcing recreation of all streams - * - * PLEASE NOTE: This is debugging-related method, and should NOT be used in real tasks - */ - public synchronized void resetPool(int deviceId) { - /* - for (CUcontext cuContext: cuPool.values()) { - log.debug("Destroying context: " + cuContext); - JCudaDriver.cuCtxDestroy(cuContext); - } - - cuPool.clear(); - contextsForDevices.clear(); - contextsPool.clear(); - cublasPool.clear(); - - solverPool.clear(); - - acquireContextForDevice(deviceId); - */ - } - - public CUcontext getCuContextForDevice(Integer deviceId) { - return cuPool.get(deviceId); - } - - /** - * This method is used to allocate - * @param context - * @param deviceId - */ - protected void getDeviceBuffers(CudaContext context, int deviceId) { - NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); //((CudaExecutioner) Nd4j.getExecutioner()).getNativeOps(); - - // we hardcode sizeOf to sizeOf(double) - int sizeOf = 8; - - val reductionPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (reductionPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] reduction buffer memory!"); - - nativeOps.memsetAsync(reductionPointer, 0, 16384 * sizeOf, 0, context.getOldStream()); - - context.syncOldStream(); - - val allocationPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (allocationPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] allocation buffer memory!"); - - val scalarPointer = nativeOps.mallocHost(sizeOf, 0); - if (scalarPointer == null) - throw new IllegalStateException("Can't allocate [HOST] scalar buffer memory!"); - - context.setBufferScalar(scalarPointer); - context.setBufferAllocation(allocationPointer); - context.setBufferReduction(reductionPointer); - - val specialPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (specialPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] special buffer memory!"); - - nativeOps.memsetAsync(specialPointer, 0, 16384 * sizeOf, 0, context.getOldStream()); - - context.setBufferSpecial(specialPointer); - } - - public ContextPack acquireContextPackForDevice(Integer deviceId) { - return new ContextPack(acquireContextForDevice(deviceId)); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java deleted file mode 100644 index 20f8adc21..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import lombok.var; -import org.apache.commons.lang3.RandomUtils; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.garbage.DeallocatableThread; -import org.nd4j.jita.allocator.garbage.GarbageResourceReference; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.api.memory.Deallocatable; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; - -import java.lang.ref.ReferenceQueue; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.LockSupport; - -/** - * @author raver119@gmail.com - */ -@Slf4j -public class LimitedContextPool extends BasicContextPool { - - // pool of free contexts - protected Map> pool = new HashMap<>(); - - // pool of used pools - protected Map acquired = new ConcurrentHashMap<>(); - //protected AtomicInteger currentPoolSize = new AtomicInteger(0); - protected List devicePoolSizes = new ArrayList<>(); - protected Map> queueMap = new HashMap<>(); - - protected ThreadLocal threadHooks = new ThreadLocal<>(); - - public LimitedContextPool() { - - int perDevicePool = CudaEnvironment.getInstance().getConfiguration().getPoolSize(); - -/* - for (int i = 0; i < 4; i++) { - val queue = new ReferenceQueue(); - val collector = new ResourceGarbageCollectorThread(i, queue); - collector.start(); - - collectors.put(i, collector); - queueMap.put(i, queue); - } -*/ - fillPoolWithResources(perDevicePool, false); - } - - protected void addResourcesToPool(int numResources) { - int device = AtomicAllocator.getInstance().getDeviceId(); - - val handle = createNewCublasHandle(); - for (int cnt = 0; cnt < numResources; cnt++) { - val context = createNewStream(device); - context.initOldStream(); - getDeviceBuffers(context, device); - context.setHandle(handle); - - context.syncOldStream(); - - pool.get(device).add(context); - } - } - - protected synchronized void fillPoolWithResources(int numResources, boolean restoreDevice) { - List devices = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices(); - - int cDevice = 0; - if (restoreDevice) { - cDevice = AtomicAllocator.getInstance().getDeviceId(); - } - - NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - for (Integer device : devices) { - nativeOps.setDevice(device); - pool.put(device, new LinkedBlockingQueue()); - devicePoolSizes.add(new AtomicInteger(numResources)); - - val handle = createNewCublasHandle(); - val solverHandle = createNewSolverHandle(); - for (int cnt = 0; cnt < numResources; cnt++) { - val context = createNewStream(device); - context.initOldStream(); - getDeviceBuffers(context, device); - context.setHandle(handle); - context.setSolverHandle(solverHandle); - - context.syncOldStream(); - - pool.get(device).add(context); - } - - - } - - if (restoreDevice) { - nativeOps.setDevice(cDevice); - } - } - - public void removeAcquired() { - val threadIdx = Thread.currentThread().getId(); - acquired.remove(threadIdx); - } - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - val threadIdx = Thread.currentThread().getId(); - var context = acquired.get(threadIdx); - if (context != null && deviceId == context.getDeviceId()) { - return context; - } - - //log.info("Setting device to {}", deviceId); - nativeOps.setDevice(deviceId); - context = pool.get(deviceId).poll(); - if (context != null) { - //val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue()); - //context.attachReference(reference); - context.setDeviceId(deviceId); - context.setThreadId(threadIdx); - val hook = new DeallocatableThread(Thread.currentThread(), context); - threadHooks.set(hook); - Nd4j.getDeallocatorService().pickObject(hook); - - - acquired.put(threadIdx, context); - return context; - } else { - - do { - try { - Nd4j.getMemoryManager().invokeGc(); - - context = pool.get(deviceId).poll(1, TimeUnit.SECONDS); - if (context != null) { - //val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue()); - //context.attachReference(reference); - context.setDeviceId(deviceId); - context.setThreadId(threadIdx); - val hook = new DeallocatableThread(Thread.currentThread(), context); - threadHooks.set(hook); - Nd4j.getDeallocatorService().pickObject(hook); - - acquired.put(threadIdx, context); - } else { - val currentPoolSize = devicePoolSizes.get(deviceId); - synchronized (currentPoolSize) { - if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) { - addResourcesToPool(16); - - // there's possible race condition, but we don't really care - currentPoolSize.addAndGet(16); - log.warn("Initial pool size: {}; Current pool size: {}", CudaEnvironment.getInstance().getConfiguration().getPoolSize(), currentPoolSize.get()); - } else { - log.warn("Can't allocate new context, sleeping..."); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(500); - } catch (Exception e) { - // - } - } - } - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } while (context == null); - - return context; - } - } - - @Override - @Deprecated - public ContextPack acquireContextPackForDevice(Integer deviceId) { - return new ContextPack(acquireContextForDevice(deviceId)); - } - - @Override - public CudaContext getContextForDevice(Integer deviceId) { - return acquireContextForDevice(deviceId); - } - - @Override - public void releaseContext(CudaContext context) { - val threadIdx = context.getThreadId(); - val deviceId = context.getDeviceId(); - - context.setThreadId(-1); - - acquired.remove(threadIdx); - pool.get(deviceId).add(context); - } - - /* - private class ResourceGarbageCollectorThread extends Thread implements Runnable { - private final ReferenceQueue queue; - - public ResourceGarbageCollectorThread(int threadId, @NonNull ReferenceQueue queue) { - this.queue = queue; - this.setDaemon(true); - this.setName("ResourceGC thread " + threadId); - } - - @Override - public void run() { - while (true) { - GarbageResourceReference reference = (GarbageResourceReference) queue.poll(); - if (reference != null) { - CudaContext context = reference.getContext(); - val threadId = reference.getThreadId(); - val deviceId = reference.getDeviceId(); - - // there's a chance context was already released - if (context.getThreadId() != threadId) - continue; - - pool.get(deviceId).add(context); - acquired.remove(threadId); - } else { - LockSupport.parkNanos(500000L); - } - } - } - } - */ -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java deleted file mode 100644 index 55a72f5c6..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java +++ /dev/null @@ -1,109 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.jcublas.context.CudaContext; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * @author raver119@gmail.com - */ -@Deprecated -@Slf4j -public class PackedContextPool extends BasicContextPool implements ContextPool { - - protected static final int LANES_PER_THREAD = - CudaEnvironment.getInstance().getConfiguration().getCommandLanesNumber(); - - private volatile Map contextsPool = new ConcurrentHashMap<>(); - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - return acquireContextPackForDevice(deviceId).getContextForLane(0); - } - - @Override - public ContextPack acquireContextPackForDevice(Integer deviceId) { - Long threadId = Thread.currentThread().getId(); - if (!contextsPool.containsKey(threadId)) { - try { - lock.acquire(); - - ContextPack pack = new ContextPack(LANES_PER_THREAD); - for (int c = 0; c < LANES_PER_THREAD; c++) { - CudaContext context = createNewStream(deviceId); - - getDeviceBuffers(context, deviceId); - - if (cublasPool.get(deviceId) == null) { - // if we have no contexts created - it's just awesome time to attach cuBLAS handle here - log.debug("Creating new cuBLAS handle for device [{}]", deviceId); - - //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); - - cublasHandle_t handle = createNewCublasHandle(context.getOldStream()); - context.setHandle(handle); - //context.setCublasStream(cublasStream); - - cublasPool.put(deviceId, handle); - - log.debug("Creating new cuSolver handle for device [{}]...", deviceId); - - cudaStream_t solverStream = createNewStream(deviceId).getOldStream(); - - cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream); - context.setSolverHandle(solverhandle); - context.setSolverStream(solverStream); - - solverPool.put(deviceId, solverhandle); - - } else { - // just pick handle out there - log.debug("Reusing cuBLAS handle for device [{}]", deviceId); - cublasHandle_t handle = cublasPool.get(deviceId); - context.setHandle(handle); - - log.debug("Reusing solver here..."); - cusolverDnHandle_t solverHandle = solverPool.get(deviceId); - context.setSolverHandle(solverHandle); - } - - pack.addLane(c, context); - } - - contextsPool.put(threadId, pack); - - - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - lock.release(); - } - } - - return contextsPool.get(threadId); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index bbefcb0fc..ad4cad0b0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -22,8 +22,6 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.Aggressiveness; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.garbage.GarbageBufferReference; @@ -226,7 +224,7 @@ public class AtomicAllocator implements Allocator { * @return */ @Override - public ExternalContext getDeviceContext() { + public CudaContext getDeviceContext() { // FIXME: proper lock avoidance required here return memoryHandler.getDeviceContext(); } @@ -290,7 +288,7 @@ public class AtomicAllocator implements Allocator { } public Pointer getPointer(DataBuffer buffer) { - return memoryHandler.getDevicePointer(buffer, (CudaContext) getDeviceContext().getContext()); + return memoryHandler.getDevicePointer(buffer, getDeviceContext()); } /** @@ -1072,11 +1070,6 @@ public class AtomicAllocator implements Allocator { return memoryHandler.getFlowController(); } - @Override - public ContextPool getContextPool() { - return memoryHandler.getContextPool(); - } - @Override public DataBuffer getConstantBuffer(int[] array) { return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.INT); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 0655e7cb1..77c709487 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -68,101 +68,9 @@ public class CudaAffinityManager extends BasicAffinityManager { */ @Override public Integer getDeviceForCurrentThread() { - return getDeviceForThread(Thread.currentThread().getId()); + return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); } - /** - * This method returns deviceId for given thread. - * - * If no device was assigned to this thread before this call, it'll be assinged here. - * @param thread - * @return - */ - @Override - public Integer getDeviceForThread(Thread thread) { - return getDeviceForThread(thread.getId()); - } - - /** - * This method returns deviceId for given thread, identified by threadId - * - * If no device was assigned to this thread before this call, it'll be assinged here. - * - * @param threadId - * @return - */ - @Override - public Integer getDeviceForThread(long threadId) { - if (getNumberOfDevices() == 1) - return 0; - - Integer aff = affinityMap.get(threadId); - - if (aff == null) { - Integer deviceId = getNextDevice(threadId); - affinityMap.put(threadId, deviceId); - affiliated.set(new AtomicBoolean(false)); - - if (threadId == Thread.currentThread().getId()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - //logger.error("setDevice({}) called for thread {}", deviceId, Thread.currentThread().getName()); - affiliated.get().set(true); - } - - return deviceId; - } else { - - if (threadId == Thread.currentThread().getId()) { - if (affiliated.get() == null) - affiliated.set(new AtomicBoolean(false)); - - if (!affiliated.get().get()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(aff); - //logger.error("SCARY setDevice({}) called for thread {}", aff, threadId); - affiliated.get().set(true); - return aff; - } - } - - return aff; - } -/* - - - return affinityMap.get(threadId); -*/ - //return 0; - } - - /** - * This method pairs specified thread & device - * - * @param thread - * @param deviceId - */ - @Override - public void attachThreadToDevice(Thread thread, Integer deviceId) { - attachThreadToDevice(thread.getId(), deviceId); - } - - /** - * This method pairs specified thread & device - * - * @param threadId - * @param deviceId - */ - @Override - public void attachThreadToDevice(long threadId, Integer deviceId) { - val t = Thread.currentThread(); - String name = "N/A"; - if (t.getId() == threadId) - name = t.getName(); - - List devices = new ArrayList<>(CudaEnvironment.getInstance().getConfiguration().getAvailableDevices()); - logger.trace("Manually mapping thread [{} - {}] to device [{}], out of [{}] devices...", threadId, - name, deviceId, devices.size()); - affinityMap.put(threadId, deviceId); - } /** * This method returns device id available. Round-robin balancing used here. @@ -275,14 +183,13 @@ public class CudaAffinityManager extends BasicAffinityManager { val empty = array.isEmpty(); // we use this call to get device memory updated - AtomicAllocator.getInstance().getPointer(array, (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()); + AtomicAllocator.getInstance().getPointer(array, AtomicAllocator.getInstance().getDeviceContext()); int currentDeviceId = getDeviceForCurrentThread(); if (currentDeviceId != deviceId.intValue()) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - attachThreadToDevice(Thread.currentThread().getId(), deviceId); + unsafeSetDevice(deviceId); } @@ -292,8 +199,7 @@ public class CudaAffinityManager extends BasicAffinityManager { if (currentDeviceId != deviceId.intValue()) { Nd4j.getMemoryManager().releaseCurrentContext(); - attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(currentDeviceId); + unsafeSetDevice(currentDeviceId); } @@ -312,11 +218,11 @@ public class CudaAffinityManager extends BasicAffinityManager { if (buffer == null) return null; - int currentDeviceId = AtomicAllocator.getInstance().getDeviceId(); + int currentDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + if (currentDeviceId != deviceId) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), deviceId); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); } DataBuffer dstBuffer = Nd4j.createBuffer(buffer.dataType(), buffer.length(), false); @@ -324,8 +230,7 @@ public class CudaAffinityManager extends BasicAffinityManager { if (currentDeviceId != deviceId) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(currentDeviceId); - Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId); + Nd4j.getAffinityManager().unsafeSetDevice(currentDeviceId); } return dstBuffer; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 54e5df7e6..5548d854a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -143,7 +143,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, deviceId, requiredMemoryBytes); long currentOffset = constantOffsets.get(deviceId).get(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index 9a8feeb0b..07cad5269 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -72,7 +72,7 @@ public class SynchronousFlowController implements FlowController { public void synchronizeToHost(AllocationPoint point) { if (!point.isActualOnHostSide()) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); if (!point.isConstant()) waitTillFinished(point); @@ -102,7 +102,7 @@ public class SynchronousFlowController implements FlowController { if (!point.isActualOnDeviceSide()) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); long perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -135,7 +135,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareActionAllWrite(INDArray... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); val cId = allocator.getDeviceId(); for (INDArray operand : operands) { @@ -168,7 +168,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(INDArray result, INDArray... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); val cId = allocator.getDeviceId(); @@ -290,7 +290,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); if (result != null) { result.acquireLock(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java index 4d11a6564..36d8e05fb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java @@ -19,8 +19,6 @@ package org.nd4j.jita.handler; import com.google.common.collect.Table; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; @@ -295,7 +293,7 @@ public interface MemoryHandler { * This method returns ExternalContext wrapper (if applicable) * @return */ - ExternalContext getDeviceContext(); + CudaContext getDeviceContext(); void registerAction(CudaContext context, INDArray result, INDArray... operands); @@ -306,8 +304,4 @@ public interface MemoryHandler { boolean promoteObject(DataBuffer buffer); void relocateObject(DataBuffer buffer); - - ContextPool getContextPool(); - - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 02374315b..e35628728 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -25,10 +25,6 @@ import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; -import org.nd4j.jita.allocator.context.impl.PackedContextPool; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; @@ -37,6 +33,9 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.MemoryTracker; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.PointersPair; +import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; +import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; +import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; import org.nd4j.jita.allocator.utils.AllocationUtils; import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; @@ -99,8 +98,6 @@ public class CudaZeroHandler implements MemoryHandler { private final AtomicBoolean wasInitialised = new AtomicBoolean(false); - private final ContextPool contextPool; - @Getter private final MemoryProvider memoryProvider; @@ -142,7 +139,6 @@ public class CudaZeroHandler implements MemoryHandler { switch (configuration.getExecutionModel()) { case SEQUENTIAL: { this.flowController = new GridFlowController(); - this.contextPool = new LimitedContextPool(); } break; default: @@ -222,7 +218,7 @@ public class CudaZeroHandler implements MemoryHandler { boolean initialize) { long reqMemory = AllocationUtils.getRequiredMemory(shape); - CudaContext context = getCudaContext(); + val context = getCudaContext(); switch (targetMode) { case HOST: { if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { @@ -1158,8 +1154,8 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ @Override - public ExternalContext getDeviceContext() { - return new ExternalContext(getCudaContext()); + public CudaContext getDeviceContext() { + return getCudaContext(); } /** @@ -1167,30 +1163,20 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ public CudaContext getCudaContext() { - // FIXME: remove this before release - Integer deviceId = getDeviceId(); - return contextPool.acquireContextForDevice(deviceId); - } + val lc = nativeOps.defaultLaunchContext(); + // TODO: maybe make ThreadLocal cache for context? - /** - * This method does initialization for thread. - * - * - * @param threadId - */ - protected void initCudaContextForThread(Long threadId) { - - // we set device to be used prior to stream creation - - nativeOps.setDevice(getDeviceId()); - - CudaContext context = new CudaContext(); - context.initHandle(); - context.initOldStream(); - context.initStream(); - context.associateHandle(); - //contextPool.put(threadId, context); + return CudaContext.builder() + .bufferScalar(nativeOps.lcScalarPointer(lc)) + .bufferReduction(nativeOps.lcReductionPointer(lc)) + .bufferAllocation(nativeOps.lcAllocationPointer(lc)) + .bufferSpecial(nativeOps.lcScalarPointer(lc)) + .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) + .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) + .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc))) + .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) + .build(); } /** @@ -1227,11 +1213,4 @@ public class CudaZeroHandler implements MemoryHandler { public FlowController getFlowController() { return flowController; } - - @Override - public ContextPool getContextPool() { - return contextPool; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index 3263bf291..da36da6db 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -19,7 +19,6 @@ package org.nd4j.jita.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; @@ -79,7 +78,7 @@ public class CudaMemoryManager extends BasicMemoryManager { throw new RuntimeException("Failed to allocate " + bytes + " bytes from DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "] memory"); if (initialize) { - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); int i = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, bytes, 0, context.getSpecialStream()); if (i == 0) @@ -168,7 +167,7 @@ public class CudaMemoryManager extends BasicMemoryManager { */ @Override public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) { @@ -258,7 +257,7 @@ public class CudaMemoryManager extends BasicMemoryManager { AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array); if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + CudaContext context = AtomicAllocator.getInstance().getDeviceContext(); NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(AtomicAllocator.getInstance().getPointer(array, context),0, array.data().length() * Nd4j.sizeOfDataType(array.data().dataType()),0, context.getOldStream()); // we also memset host pointer @@ -289,20 +288,6 @@ public class CudaMemoryManager extends BasicMemoryManager { @Override public void releaseCurrentContext() { - // gettting context for this thread - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - - if (context == null) - return; - - // we dont want any remnaints below this line - context.syncOldStream(); - context.syncSpecialStream(); - - val pool = AtomicAllocator.getInstance().getContextPool(); - - // push it back to pool - pool.releaseContext(context); - ((LimitedContextPool) pool).removeAcquired(); + throw new UnsupportedOperationException("Not implemented yet"); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index 5e1d2eeaf..c901cdd67 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -177,7 +177,7 @@ public class CudaWorkspace extends Nd4jWorkspace { log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements, prevOffset, deviceOffset.get(), currentSize.get(), ptr.address()); if (initialize) { - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); int ret = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream()); if (ret == 0) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java index 8b37e8ead..cd0356d18 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java @@ -570,7 +570,7 @@ public class JCublasNDArray extends BaseNDArray { //Nd4j.getExecutioner().commit(); AtomicAllocator allocator = AtomicAllocator.getInstance(); - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = (CudaContext) allocator.getDeviceContext(); AllocationPoint srcPoint = allocator.getAllocationPoint(this); AllocationPoint dstPoint = allocator.getAllocationPoint(ret); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index b0db6ec50..44c361d87 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; +import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; @@ -410,6 +411,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + return Nd4j.exec(new Concat(dimension, toConcat))[0]; + + // legacy implementation +/* boolean allScalars = true; var outputShape = ArrayUtil.copy(toConcat[0].shape()); @@ -531,6 +536,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { return ret; //return super.concat(dimension, toConcat); + */ } @@ -546,7 +552,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { PointerPointer dataPointers = new PointerPointer(toConcat.length); AtomicAllocator allocator = AtomicAllocator.getInstance(); - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); int sumAlongDim = 0; @@ -783,10 +789,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { Nd4j.getExecutioner().commit(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer dataPointers = new PointerPointer(arrays.length); - PointerPointer extras = new PointerPointer(null, // not used + val dataPointers = new PointerPointer(arrays.length); + val extras = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1) ); for (int i = 0; i < arrays.length; i++) { @@ -899,10 +905,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { */ long len = target == null ? arrays[0].lengthLong() : target.lengthLong(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer dataPointers = new PointerPointer(arrays.length); - PointerPointer extras = new PointerPointer(null, // not used + val dataPointers = new PointerPointer(arrays.length); + val extras = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1) ); for (int i = 0; i < arrays.length; i++) { @@ -1249,7 +1255,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) { - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); val p = new PointerPointer<>(new Pointer[]{null, stream}); @@ -1262,7 +1268,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { Pointer dstPtr = null; long size = 0; long ssize = 0; - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); if (buffer instanceof CompressedDataBuffer) { // compressing size = ((CompressedDataBuffer) buffer).getCompressionDescriptor().getCompressedLength(); @@ -1291,7 +1297,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) { - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); Pointer srcPtr = null; Pointer dstPtr = null; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java index 592e9f3f6..74a8fc99c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java @@ -75,7 +75,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -142,7 +142,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -214,7 +214,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -330,7 +330,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -439,7 +439,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -523,7 +523,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -656,7 +656,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -766,7 +766,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -853,7 +853,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -928,7 +928,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java index 7009bfbaa..d4efb87b4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java @@ -100,7 +100,7 @@ public class JcublasLevel1 extends BaseLevel1 { val xCPointer = new CublasPointer(X, ctx); val yCPointer = new CublasPointer(Y, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); val cctx = new cublasContext(handle); synchronized (handle) { @@ -144,7 +144,7 @@ public class JcublasLevel1 extends BaseLevel1 { val xCPointer = new CublasPointer(X, ctx); val yCPointer = new CublasPointer(Y, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { val cctx = new cublasContext(handle); cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream())); @@ -177,7 +177,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer cAPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -235,7 +235,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer cAPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -276,7 +276,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -306,7 +306,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -337,7 +337,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -361,7 +361,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -395,7 +395,7 @@ public class JcublasLevel1 extends BaseLevel1 { // CublasPointer xAPointer = new CublasPointer(X, ctx); // CublasPointer xBPointer = new CublasPointer(Y, ctx); - // cublasHandle_t handle = ctx.getHandle(); + // cublasHandle_t handle = ctx.getCublasHandle(); ((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha)); @@ -424,7 +424,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -446,7 +446,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -540,7 +540,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -563,7 +563,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java index 652d7b328..05f33ac3e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java @@ -62,7 +62,7 @@ public class JcublasLevel2 extends BaseLevel2 { CublasPointer cBPointer = new CublasPointer(X, ctx); CublasPointer cCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -134,7 +134,7 @@ public class JcublasLevel2 extends BaseLevel2 { CublasPointer cBPointer = new CublasPointer(X, ctx); CublasPointer cCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index 99a9718c8..7f8f9bb51 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -72,7 +72,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer cBPointer = new CublasPointer(B, ctx); CublasPointer cCPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -118,7 +118,7 @@ public class JcublasLevel3 extends BaseLevel3 { val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -144,7 +144,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -169,7 +169,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -206,7 +206,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -236,7 +236,7 @@ public class JcublasLevel3 extends BaseLevel3 { val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -261,7 +261,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -286,7 +286,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -311,7 +311,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -336,7 +336,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -362,7 +362,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 8d29e2b7b..1bec19dd0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -121,7 +121,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda Nd4j.getDeallocatorService().pickObject(this); // now we're - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); val perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -1522,7 +1522,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda lazyAllocateHostPointer(); } - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java index bee68b3bf..19e8f8df6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.jcublas.compression; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.*; import org.nd4j.compression.impl.AbstractCompressor; @@ -118,7 +119,7 @@ public class CudaThreshold extends AbstractCompressor { DataBuffer result = Nd4j.createBuffer(type, originalLength, false); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); PointerPointer extras = new PointerPointer(32).put(1, context.getOldStream()); @@ -139,7 +140,7 @@ public class CudaThreshold extends AbstractCompressor { int numThreads = 1024; int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0 ? 0 : 1)); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); DataBuffer blocksBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(numBlocks+1, true) : Nd4j.getDataBufferFactory().createInt(numBlocks+1, true, Nd4j.getMemoryManager().getCurrentWorkspace()); PointerPointer extras = new PointerPointer(32).put(1, context.getOldStream()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java index a1f1b39be..826bb0797 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java @@ -16,8 +16,7 @@ package org.nd4j.linalg.jcublas.context; -import lombok.Data; -import lombok.val; +import lombok.*; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -44,49 +43,32 @@ import java.util.concurrent.atomic.AtomicBoolean; * */ @Data +@AllArgsConstructor +@NoArgsConstructor +@Builder public class CudaContext { - //private CUcontext context; - //private CUstream stream; - //private CUevent cUevent; + + // execution stream private cudaStream_t oldStream; - private cudaStream_t solverStream; - + // memcpy stream private cudaStream_t specialStream; - //private cudaEvent_t oldEvent; - private cublasHandle_t handle; + // exactly what it says + private cublasHandle_t cublasHandle; private cusolverDnHandle_t solverHandle; - private CublasPointer resultPointer; - private AtomicBoolean oldStreamReturned = new AtomicBoolean(false); - private AtomicBoolean handleReturned = new AtomicBoolean(false); - private AtomicBoolean streamReturned = new AtomicBoolean(false); - private boolean streamFromPool = true; - private boolean handleFromPool = true; - private boolean oldStreamFromPool = true; - private boolean free = true; - private boolean oldEventDestroyed = true; - private boolean eventDestroyed = true; + // temporary buffers, exactly 1 per thread private Pointer bufferReduction; private Pointer bufferAllocation; private Pointer bufferScalar; + + // legacy. to be removed. private Pointer bufferSpecial; - private GarbageResourceReference reference; private int deviceId = -1; - private long threadId; - - private int laneId = 0; - - private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - - public CudaContext(boolean free) { - this(); - this.free = free; - } + private transient final static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override public String toString() { @@ -94,34 +76,16 @@ public class CudaContext { "bufferReduction=" + bufferReduction + ", bufferScalar=" + bufferScalar + ", deviceId=" + deviceId + - ", threadId=" + threadId + - ", laneId=" + laneId + '}'; } - public void attachReference(GarbageResourceReference ref) { - reference = ref; - } - - - public CudaContext() { - // ContextHolder.getInstance().setContext(); - } - - /** - * Synchronizes on the new - * stream - */ - public void syncStream() { - //JCudaDriver.cuStreamSynchronize(stream); - } - /** * Synchronizes * on the old stream */ public void syncOldStream() { - syncOldStream(false); + if (nativeOps.streamSynchronize(oldStream) == 0) + throw new ND4JIllegalStateException("CUDA stream synchronization failed"); } public void syncSpecialStream() { @@ -129,125 +93,21 @@ public class CudaContext { throw new ND4JIllegalStateException("CUDA special stream synchronization failed"); } - public void syncOldStream(boolean syncCuBlas) { - // ContextHolder.getInstance().setContext(); - if (nativeOps.streamSynchronize(oldStream) == 0) - throw new ND4JIllegalStateException("CUDA stream synchronization failed"); - } - public Pointer getCublasStream() { + // FIXME: can we cache this please val lptr = new PointerPointer(this.getOldStream()); return lptr.get(0); } - - public void syncSolverStream() { - if (solverStream != null) { - if (nativeOps.streamSynchronize(solverStream) == 0) - throw new ND4JIllegalStateException("CUDA stream synchronization failed"); - } else - throw new IllegalStateException("cuBLAS stream isnt set"); + public cublasHandle_t getCublasHandle() { + // FIXME: can we cache this please + val lptr = new PointerPointer(cublasHandle); + return new cublasHandle_t(lptr.get(0)); } - /** - * Associates - * the handle on this context - * to the given stream - */ - public synchronized void associateHandle() { - //JCublas2.cublasSetStream(handle,oldStream); + public cusolverDnHandle_t getSolverHandle() { + // FIXME: can we cache this please + val lptr = new PointerPointer(solverHandle); + return new cusolverDnHandle_t(lptr.get(0)); } - - - - /** - * Initializes the stream - */ - public void initStream() { - // ContextHolder.getInstance().setContext(); - /* - if(stream == null) { - stream = new CUstream(); - JCudaDriver.cuStreamCreate(stream, CUstream_flags.CU_STREAM_DEFAULT); - streamFromPool = false; - eventDestroyed = false; - } - */ - } - - /** - * Initializes the old stream - */ - public void initOldStream() { - // ContextHolder.getInstance().setContext(); - if (oldStream == null) { - oldStreamFromPool = false; - oldStream = new cudaStream_t(nativeOps.createStream()); - //JCuda.cudaStreamCreate(oldStream); - - specialStream = new cudaStream_t(nativeOps.createStream()); - //JCuda.cudaStreamCreate(specialStream); - } - - } - - - - /** - * Initializes a handle and - * associates with the given stream. - * initOldStream() should be called first - * - */ - public void initHandle() { - /* - - We don't create handles here anymore - - if(handle == null) { - handle = new cublasHandle(); - JCublas2.cublasCreate(handle); - handleFromPool = false; - } - */ - } - - /** - * Destroys the context - * and associated resources - */ - @Deprecated - public void destroy(CublasPointer resultPointer, boolean freeIfNotEqual) {} - - - /** - * Destroys the context - * and associated resources - */ - @Deprecated - public void destroy() { - - } - - - /** - * Finishes a blas operation - * and destroys this context - */ - public void finishBlasOperation() { - //destroy(); - } - - /** - * Sets up a context with an old stream - * and a blas handle - * @return the cuda context - * as setup for cublas usage - */ - public static CudaContext getBlasContext() { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - //context.syncOldStream(false); - return context; - } - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 3d1f66ef8..26f228430 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -60,6 +60,7 @@ import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.compression.ThresholdCompression; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.AddressRetriever; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; @@ -1513,11 +1514,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { val surfaceBuffer = (BaseCudaDataBuffer) getBuffer(batch); surfaceBuffer.lazyAllocateHostPointer(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); - IntPointer pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) + val pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) .asIntPointer(); - AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); + val surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); int maxTypes = 5; @@ -1659,7 +1660,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { this.exec(single); } - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); context.syncOldStream(); } @@ -1671,9 +1672,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { int numIntArrays = op.getIntArrayArguments().size(); int numRealArguments = op.getRealArguments().size(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer extraArgs = new PointerPointer(32); + val extraArgs = new PointerPointer(32); extraArgs.put(0, null); extraArgs.put(1, context.getOldStream()); extraArgs.put(2, new CudaPointer(1)); @@ -1890,8 +1891,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void commit() { - ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).syncOldStream(); - ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).syncSpecialStream(); + AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); + AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); } @Override @@ -1901,14 +1902,14 @@ public class CudaExecutioner extends DefaultOpExecutioner { int numThreads = 1024; int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0 ? 0 : 1)); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); DataBuffer blocksBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(numBlocks+1, true) : Nd4j.getDataBufferFactory().createInt(numBlocks+1, true, Nd4j.getMemoryManager().getCurrentWorkspace()); if (extraz.get() == null) extraz.set(new PointerPointer(32)); - PointerPointer extras = extraz.get().put(1, context.getOldStream()); + val extras = extraz.get().put(1, context.getOldStream()); @@ -2024,7 +2025,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer result = target.data(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2254,10 +2255,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); val name = op.opName(); try (val context = (CudaOpContext) buildContext()) { + context.markInplace(op.isInplaceCall()); // transferring rng state @@ -2279,6 +2281,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); return result; + } catch (ND4JOpProfilerException e) { + throw e; } catch (Exception e) { throw new RuntimeException("Op [" + name + "] execution failed", e); } @@ -2545,11 +2549,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, OpContext context) { - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + long st = profilingConfigurableHookIn(op); + + val ctx = AtomicAllocator.getInstance().getDeviceContext(); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); + profilingConfigurableHookOut(op, st); + if (context.getOutputArrays().isEmpty()) return new INDArray[0]; else @@ -2559,7 +2567,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArrayStatistics inspectArray(@NonNull INDArray array) { val debugInfo = new Nd4jCuda.DebugInfo(); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); AtomicAllocator.getInstance().synchronizeHostData(array); if (extraz.get() == null) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index b2c86bf3a..ca8e4eb07 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -164,9 +164,9 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio } protected boolean compareDevicePointers(INDArray array, Op op) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - Pointer pointer = AtomicAllocator.getInstance().getPointer(array, context); + val pointer = AtomicAllocator.getInstance().getPointer(array, context); long opZ = AtomicAllocator.getInstance().getPointer(op.z(), context).address(); long opX = AtomicAllocator.getInstance().getPointer(op.x(), context).address(); @@ -193,7 +193,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio protected boolean compareHostPointers(INDArray array, Op op) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); Pointer pointer = AtomicAllocator.getInstance().getPointer(array, context); @@ -506,9 +506,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio AtomicAllocator allocator = AtomicAllocator.getInstance(); - // CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); - // FIXME: do not leave it as is - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); pointers.setX(allocator.getPointer(op.x(), context)); pointers.setXShapeInfo(allocator.getPointer(op.x().shapeInfoDataBuffer(), context)); @@ -930,7 +928,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio public void flushQueueBlocking() { flushQueue(); - val context =((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()); + val context = AtomicAllocator.getInstance().getDeviceContext(); context.syncSpecialStream(); context.syncOldStream(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 8db04257b..749f5cc96 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -84,7 +84,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { // FIXME: remove Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setInputArray(index, array); @@ -94,7 +94,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void setOutputArray(int index, @NonNull INDArray array) { Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setOutputArray(index, array); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index bc33b9e55..b15e4455e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3104,6 +3104,15 @@ public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); +public native OpaqueLaunchContext defaultLaunchContext(); +public native @Cast("Nd4jPointer") Pointer lcScalarPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcReductionPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcAllocationPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcExecutionStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); + // #endif //NATIVEOPERATIONS_NATIVEOPS_H @@ -4036,6 +4045,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public native void printBuffer(); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * print element by element consequently in a way they (elements) are stored in physical memory + */ + public native void printLinearBuffer(); + /** * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status */ @@ -7047,9 +7061,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @@ -7894,10 +7908,6 @@ public static final int PREALLOC_SIZE = 33554432; * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @@ -8745,10 +8755,6 @@ public static final int PREALLOC_SIZE = 33554432; * @param originalTadNum the tad number for the reduced version of the problem */ -/** - * Returns the prod of the data - * up to the given length - */ /** * Returns the prod of the data @@ -9931,6 +9937,8 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 6b9979ec2..51b9ce7e4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -121,6 +121,7 @@ public class Nd4jCudaPresets implements InfoMapper { .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) + .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java index 7c21fc86f..c19adf4ad 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java @@ -22,7 +22,6 @@ import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.junit.Ignore; import org.junit.Test; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.MemoryTracker; @@ -539,15 +538,6 @@ public class AllocatorTest { assertEquals(currEventsNumber+5, controller.getEventsProvider().getEventsNumber()); } - @Test - public void testReleaseContext() { - LimitedContextPool pool = (LimitedContextPool) AtomicAllocator.getInstance().getContextPool(); - System.out.println(pool.acquireContextForDevice(0)); - INDArray x = Nd4j.rand(1,10); - pool.releaseContext(pool.getContextForDevice(0)); - System.out.println(pool.getContextForDevice(0)); - } - @Test public void testDataBuffers() { INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java index cecfe07d0..9584d5692 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java @@ -38,14 +38,16 @@ public class DeviceLocalNDArrayTests { val dl = new DeviceLocalNDArray(arr); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); dl.get().addi(1.0); Nd4j.getExecutioner().commit(); } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); } @@ -60,9 +62,11 @@ public class DeviceLocalNDArrayTests { val dl = new DeviceLocalNDArray(arr); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); for (int i = 0; i < 10; i++) { val tmp = Nd4j.create(DataType.DOUBLE, shape); tmp.addi(1.0); @@ -70,7 +74,7 @@ public class DeviceLocalNDArrayTests { } } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); @@ -79,14 +83,16 @@ public class DeviceLocalNDArrayTests { System.gc(); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); dl.get().addi(1.0); Nd4j.getExecutioner().commit(); } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 28c0b12b3..2b47103c3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.ops.custom.Flatten; +import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; @@ -572,6 +573,10 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { if (toConcat.length == 1) return toConcat[0]; + return Nd4j.exec(new Concat(dimension, toConcat))[0]; + + // legacy implementation +/* // if reusable var wasn't created for this thread, or is smaller then needed - set it to new value if (extrazA.get() == null || extrazB.get() == null || extrazSize.get() == null || extrazSize.get() < toConcat.length) { extrazA.set(new PointerPointer(toConcat.length)); @@ -627,6 +632,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { null, null); return ret; + */ } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 5e2a4e296..52fe5c652 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3104,6 +3104,15 @@ public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); +public native OpaqueLaunchContext defaultLaunchContext(); +public native @Cast("Nd4jPointer") Pointer lcScalarPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcReductionPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcAllocationPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcExecutionStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); + // #endif //NATIVEOPERATIONS_NATIVEOPS_H @@ -4036,6 +4045,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public native void printBuffer(); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * print element by element consequently in a way they (elements) are stored in physical memory + */ + public native void printLinearBuffer(); + /** * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status */ @@ -7047,9 +7061,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); + @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @@ -7894,10 +7908,6 @@ public static final int PREALLOC_SIZE = 33554432; * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @@ -8745,10 +8755,6 @@ public static final int PREALLOC_SIZE = 33554432; * @param originalTadNum the tad number for the reduced version of the problem */ -/** - * Returns the prod of the data - * up to the given length - */ /** * Returns the prod of the data @@ -9726,10 +9732,13 @@ public static final int PREALLOC_SIZE = 33554432; // #ifndef __CLION_IDE__ // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) // #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) // #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} + // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} // #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } @@ -9739,8 +9748,10 @@ public static final int PREALLOC_SIZE = 33554432; // #else // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) // #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) @@ -9776,6 +9787,12 @@ public static final int PREALLOC_SIZE = 33554432; // #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; // #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) +// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; +// #define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; +// #define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + // #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; // #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) @@ -9804,6 +9821,7 @@ public static final int PREALLOC_SIZE = 33554432; // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) +public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; @@ -22697,6 +22715,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 5ad008055..dd47eb25d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -164,6 +164,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) + .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index d37ddb889..638cd8ac3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6521,6 +6521,15 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, array2); } + @Test + public void testRndBloat16() { + INDArray x = Nd4j.rand(DataType.BFLOAT16 , 'c', new long[]{5}); + assertTrue(x.sumNumber().floatValue() > 0); + + x = Nd4j.randn(DataType.BFLOAT16 , 10); + assertTrue(x.sumNumber().floatValue() > 0); + } + @Test public void testLegacyDeserialization_2() throws Exception { val f = new ClassPathResource("legacy/NDArray_longshape_float.bin").getFile(); @@ -7150,7 +7159,11 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getRandom().setSeed(12345); INDArray a = Nd4j.rand(2,5); INDArray b = Nd4j.rand(5,3); - INDArray exp = a.mmul(b).transpose(); + INDArray exp = a.mmul(b); + Nd4j.getExecutioner().commit(); + + exp = exp.transpose(); + INDArray act = a.mmul(b, MMulTranspose.builder().transposeResult(true).build()); assertEquals(exp, act); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index b068a1a65..1bf709a5f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -101,6 +101,22 @@ public class RngTests extends BaseNd4jTest { } + @Test + void testRandomBinomial() { + //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. + INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); + assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial + + x = Nd4j.randomBinomial(10, 0.5, x); + assertTrue(x.sum().getDouble(0) > 0.0); + + x = Nd4j.randomExponential(0.5, 3,3); + assertTrue(x.sum().getDouble(0) > 0.0); + + x = Nd4j.randomExponential(0.5, x); + assertTrue(x.sum().getDouble(0) > 0.0); + } + @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index dfcf5dc79..f46d5e694 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -1073,11 +1073,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest { List arrays = new ArrayList<>(); val num = 10; for (int i = 0; i < num; i++) { - arrays.add(Nd4j.create(20, 20).assign(i)); + arrays.add(Nd4j.create(5, 20).assign(i)); } INDArray pile = Nd4j.pile(arrays); + log.info("Pile: {}", pile); + INDArray[] tears = Nd4j.tear(pile, 1, 2); for (int i = 0; i < num; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 8a67bd2c2..d0c61de9b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -444,6 +444,8 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.exec(op); //Should trigger NaN panic fail(); } catch (Exception e){ + //throw new RuntimeException(e); + log.info("Message: {}", e.getMessage()); assertTrue(e.getMessage(), e.getMessage().contains("NaN")); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 0fecaa6fe..596bf16a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; +import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -212,7 +213,7 @@ public class ConcatTestsC extends BaseNd4jTest { assertEquals(exp, concat2); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = ND4JIllegalStateException.class) public void testConcatVector() { System.out.println(Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1))); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml index 7d69be7b6..734b1b738 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml @@ -50,6 +50,8 @@ httpmime ${httpmime.version} + + com.mashape.unirest unirest-java diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index c895a9cc6..1122f90d7 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -54,11 +54,13 @@ httpmime ${httpmime.version} + com.mashape.unirest unirest-java ${unirest.version} + org.nd4j nd4j-jackson diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index dd698b96c..f69b9c24b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -21,13 +21,15 @@ import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import com.beust.jcommander.Parameters; import com.google.common.primitives.Ints; -import com.mashape.unirest.http.HttpResponse; + +import org.nd4j.shade.jackson.databind.ObjectMapper; import com.mashape.unirest.http.Unirest; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import io.aeron.driver.ThreadingMode; import lombok.Data; import lombok.NoArgsConstructor; +import lombok.val; import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; import org.json.JSONObject; @@ -49,7 +51,6 @@ import org.nd4j.parameterserver.updater.SoftSyncParameterUpdater; import org.nd4j.parameterserver.updater.SynchronousParameterUpdater; import org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage; import org.nd4j.parameterserver.util.CheckSocket; -import org.nd4j.shade.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -342,7 +343,7 @@ public class ParameterServerSubscriber implements AutoCloseable { JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState)); String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId); - HttpResponse entity = Unirest.post(url).header("Content-Type", "application/json") + val entity = Unirest.post(url).header("Content-Type", "application/json") .body(jsonObject).asString(); } catch (Exception e) { failCount.incrementAndGet(); diff --git a/nd4j/nd4j-remote/README.md b/nd4j/nd4j-remote/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/nd4j/nd4j-serde/nd4j-grpc/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml similarity index 97% rename from nd4j/nd4j-serde/nd4j-grpc/pom.xml rename to nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 9105e333b..1f668eb66 100644 --- a/nd4j/nd4j-serde/nd4j-grpc/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -19,13 +19,13 @@ - nd4j-serde + nd4j-remote org.nd4j 1.0.0-SNAPSHOT 4.0.0 - nd4j-grpc + nd4j-grpc-client nd4j-grpc diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java diff --git a/nd4j/nd4j-serde/nd4j-grpc/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-grpc/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java rename to nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/pom.xml b/nd4j/nd4j-remote/nd4j-json-client/pom.xml new file mode 100644 index 000000000..d1ceeeda9 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/pom.xml @@ -0,0 +1,70 @@ + + + + 4.0.0 + jar + + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + + + nd4j-json-client + + nd4j-json-client + + + UTF-8 + 1.7 + 1.7 + + + + + junit + junit + test + + + + com.mashape.unirest + unirest-java + ${unirest.version} + + + + org.slf4j + slf4j-api + + + + org.nd4j + jackson + ${project.version} + + + + + + testresources + + + + org.nd4j + nd4j-native + ${project.version} + test + + + org.nd4j + nd4j-native + ${project.version} + ${javacpp.platform} + test + + + + + diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java new file mode 100644 index 000000000..bbd6ac49e --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java @@ -0,0 +1,153 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients; + +import com.mashape.unirest.http.HttpResponse; +import com.mashape.unirest.http.Unirest; +import com.mashape.unirest.http.exceptions.UnirestException; +import lombok.Builder; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.json.JSONObject; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * This class provides remote inference functionality via JSON-powered REST APIs. + * + * Basically we assume that there's remote JSON server available (on bare metal or in k8s/swarm/whatever cluster), and with proper serializers/deserializers provided we can issue REST requests and get back responses. + * So, in this way application logic can be separated from DL logic. + * + * You just need to provide serializer/deserializer and address of the REST server, i.e. "http://model:8080/v1/serving" + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + */ +@Slf4j +public class JsonRemoteInference { + private String endpointAddress; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + + @Builder + public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer inputSerializer, @NonNull JsonDeserializer outputDeserializer) { + this.endpointAddress = endpointAddress; + this.serializer = inputSerializer; + this.deserializer = outputDeserializer; + } + + private O processResponse(HttpResponse response) throws IOException { + if (response.getStatus() != 200) + throw new IOException("Inference request returned bad error code: " + response.getStatus()); + + O result = deserializer.deserialize(response.getBody()); + if (result == null) { + throw new IOException("Deserialization failed!"); + } + return result; + } + + /** + * This method does remote inference in a blocking way + * + * @param input + * @return + * @throws IOException + */ + public O predict(I input) throws IOException { + try { + val stringResult = Unirest.post(endpointAddress) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(new JSONObject(serializer.serialize(input))).asString(); + + return processResponse(stringResult); + } catch (UnirestException e) { + throw new IOException(e); + } + } + + /** + * This method does remote inference in asynchronous way, returning Future instead + * @param input + * @return + */ + public Future predictAsync(I input) { + val stringResult = Unirest.post(endpointAddress) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(new JSONObject(serializer.serialize(input))).asStringAsync(); + return new InferenceFuture(stringResult); + } + + /** + * This class holds a Future of the object returned by remote inference server + */ + private class InferenceFuture implements Future { + private Future> unirestFuture; + + private InferenceFuture(@NonNull Future> future) { + this.unirestFuture = future; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return unirestFuture.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return unirestFuture.isCancelled(); + } + + @Override + public boolean isDone() { + return unirestFuture.isDone(); + } + + @Override + public O get() throws InterruptedException, ExecutionException { + val stringResult = unirestFuture.get(); + + try { + return processResponse(stringResult); + } catch (IOException e) { + throw new ExecutionException(e); + } + } + + @Override + public O get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + val stringResult = unirestFuture.get(timeout, unit); + + try { + return processResponse(stringResult); + } catch (IOException e) { + throw new ExecutionException(e); + } + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java similarity index 59% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java rename to nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java index 5045c8870..77f8939b2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java @@ -14,28 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.context; - -import org.nd4j.linalg.jcublas.context.CudaContext; +package org.nd4j.remote.clients.serde; /** - * This interface describes pool of CudaContext objects, used to execute kernels + * This interface describes basic JSON deserializer interface used for JsonRemoteInference + * @param type of the deserializable class + * * @author raver119@gmail.com */ -public interface ContextPool { +public interface JsonDeserializer { + /** - * This method returns CudaContext for given device - * @param deviceId + * This method serializes given object into JSON-string + * @param json string containing JSON representation of the object * @return */ - CudaContext acquireContextForDevice(Integer deviceId); - - @Deprecated - ContextPack acquireContextPackForDevice(Integer deviceId); - - /** - * This method returns CudaContext to the pool for reuse - * @param context - */ - void releaseContext(CudaContext context); + T deserialize(String json); } diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java new file mode 100644 index 000000000..4258b98cc --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde; + +/** + * This interface describes basic JSON serializer interface used for JsonRemoteInference + * @param type of the serializable class + * + * @author raver119@gmail.com + */ +public interface JsonSerializer { + + /** + * This method serializes given object into JSON-string + * + * @param o object to be serialized + * @return + */ + String serialize(T o); +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java new file mode 100644 index 000000000..fa02496bf --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +public abstract class AbstractSerDe implements JsonDeserializer, JsonSerializer { + protected ObjectMapper objectMapper = new ObjectMapper(); + + + protected String serializeClass(@NonNull T obj) { + try { + return objectMapper.writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + protected T deserializeClass(@NonNull String json, @NonNull Class cls) { + try { + return objectMapper.readValue(json, cls); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java new file mode 100644 index 000000000..db92a4346 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Boolean. Single value only. + */ +public class BooleanSerde extends AbstractSerDe { + @Override + public Boolean deserialize(@NonNull String json) { + return deserializeClass(json, Boolean.class); + } + + @Override + public String serialize(@NonNull Boolean o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java new file mode 100644 index 000000000..fbe44a101 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +/** + * This class provides JSON ser/de for Java double[] + */ +public class DoubleArraySerde extends AbstractSerDe { + + @Override + public String serialize(@NonNull double[] data) { + return serializeClass(data); + } + + @Override + public double[] deserialize(@NonNull String json) { + return deserializeClass(json, double[].class); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java new file mode 100644 index 000000000..d12a1dc66 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Double. Single value only. + */ +public class DoubleSerde extends AbstractSerDe { + @Override + public Double deserialize(@NonNull String json) { + return deserializeClass(json, Double.class); + } + + @Override + public String serialize(@NonNull Double o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java similarity index 50% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java rename to nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java index b850f7836..24c78ecfd 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java @@ -13,38 +13,30 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +package org.nd4j.remote.clients.serde.impl; -package org.nd4j.jita.allocator.garbage; -import org.nd4j.linalg.api.memory.Deallocatable; -import org.nd4j.linalg.api.memory.Deallocator; -import org.nd4j.linalg.jcublas.context.CudaContext; +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + /** - * This class enables Thread tracking via DeallocatorService - * @author raver119@gmail.com + * This class provides JSON ser/de for Java float[] */ -public class DeallocatableThread implements Deallocatable { - private long threadId; - private CudaContext context; +public class FloatArraySerde extends AbstractSerDe { - public DeallocatableThread(Thread thread, CudaContext context) { - this.threadId = thread.getId(); - this.context = context; + @Override + public String serialize(@NonNull float[] data) { + return serializeClass(data); } @Override - public String getUniqueId() { - return "thread_" + threadId; - } - - @Override - public Deallocator deallocator() { - return new ContextDeallocator(context); - } - - @Override - public int targetDevice() { - return context.getDeviceId(); + public float[] deserialize(@NonNull String json) { + return deserializeClass(json, float[].class); } } diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java new file mode 100644 index 000000000..14b822ba7 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Float. Single value only. + */ +public class FloatSerde extends AbstractSerDe { + @Override + public Float deserialize(@NonNull String json) { + return deserializeClass(json, Float.class); + } + + @Override + public String serialize(@NonNull Float o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java new file mode 100644 index 000000000..0d8b40272 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; + +/** + * This class provides JSON ser/de for Java Integer. Single value only. + */ +public class IntegerSerde extends AbstractSerDe { + @Override + public Integer deserialize(@NonNull String json) { + return deserializeClass(json, Integer.class); + } + + @Override + public String serialize(@NonNull Integer o) { + return serializeClass(o); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java new file mode 100644 index 000000000..20e1d8fd4 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.clients.serde.impl; + +import lombok.NonNull; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; + +/** + * This class provides fake JSON serializer/deserializer functionality for String. + * It doesn't put any JSON-specific bits into actual string + */ +public class StringSerde extends AbstractSerDe { + + @Override + public String serialize(@NonNull String data) { + return serializeClass(data); + } + + @Override + public String deserialize(@NonNull String json) { + return deserializeClass(json, String.class); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/README.md b/nd4j/nd4j-remote/nd4j-json-server/README.md new file mode 100644 index 000000000..963890a7b --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/README.md @@ -0,0 +1,35 @@ +## SameDiff model serving + +This modules provides JSON-based serving of SameDiff models + +## Example + +First of all we'll create server instance. Most probably you'll do it in application that will be running in container +```java +val server = SameDiffJsonModelServer.builder() + .adapter(new StringToSentimentAdapter()) + .model(mySameDiffModel) + .port(8080) + .serializer(new SentimentSerializer()) + .deserializer(new StringDeserializer()) + .build(); + +server.start(); +server.join(); +``` + +Now, presumably in some other container, we'll set up remote inference client: +```java +val client = JsonRemoteInference.builder() + .endpointAddress("http://youraddress:8080/v1/serving") + .serializer(new StringSerializer()) + .deserializer(new SentimentDeserializer()) + .build(); + +Sentiment result = client.predict(myText); +``` + On top of that, there's async call available, for cases when you need to chain multiple requests to one or multiple remote model servers. + +```java +Future result = client.predictAsync(myText); +``` \ No newline at end of file diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml new file mode 100644 index 000000000..d5e506839 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -0,0 +1,179 @@ + + + + 4.0.0 + jar + + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + + + nd4j-json-server + nd4j-json-server + + + UTF-8 + 1.7 + 1.7 + 2.26 + + + + + junit + junit + test + + + + org.nd4j + nd4j-json-client + ${project.version} + + + + org.slf4j + slf4j-api + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + + org.eclipse.jetty + jetty-server + 9.4.19.v20190610 + + + + org.eclipse.jetty + jetty-servlet + 9.4.19.v20190610 + + + + org.glassfish.jersey.inject + jersey-hk2 + ${jersey.version} + + + + org.glassfish.jersey.media + jersey-media-json-processing + ${jersey.version} + + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + activation + 1.1 + + + + + + + nd4j-tests-cpu + + true + + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + + + testresources + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.source} + ${maven.compiler.target} + + + + + diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java new file mode 100644 index 000000000..93f640b75 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java @@ -0,0 +1,288 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.nd4j.adapters.InputAdapter; +import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; +import org.nd4j.remote.serving.ModelServingServlet; +import org.nd4j.remote.serving.SameDiffServlet; + +import java.util.List; + +/** + * This class provides JSON-powered model serving functionality for SameDiff graphs. + * Server url will be http://0.0.0.0:{port}>/v1/serving + * Server only accepts POST requests + * + * @param type of the input class, i.e. String + * @param type of the output class, i.e. Sentiment + * + * @author raver119@gmail.com + */ +@Slf4j +public class SameDiffJsonModelServer { + + + protected SameDiff sdModel; + protected final JsonSerializer serializer; + protected final JsonDeserializer deserializer; + protected final InferenceAdapter inferenceAdapter; + protected final int port; + + // this servlet will be used to serve models + protected ModelServingServlet servingServlet; + + // HTTP server instance + protected Server server; + + // for SameDiff only + protected String[] orderedInputNodes; + protected String[] orderedOutputNodes; + + protected SameDiffJsonModelServer(@NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer, int port) { + Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535"); + + this.inferenceAdapter = inferenceAdapter; + this.serializer = serializer; + this.deserializer = deserializer; + this.port = port; + } + + //@Builder + public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) { + this(inferenceAdapter, serializer, deserializer, port); + this.sdModel = sdModel; + this.orderedInputNodes = orderedInputNodes; + this.orderedOutputNodes = orderedOutputNodes; + + // TODO: both lists of nodes should be validated, to make sure nodes specified here exist in actual model + if (orderedInputNodes != null) { + // input nodes list might be null. strange but ok + } + + Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "SameDiff serving requires at least 1 output node"); + } + + protected void start(int port, @NonNull ModelServingServlet servlet) throws Exception { + val context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + + server = new Server(port); + server.setHandler(context); + + val jerseyServlet = context.addServlet(org.glassfish.jersey.servlet.ServletContainer.class, "/*"); + jerseyServlet.setInitOrder(0); + jerseyServlet.setServlet(servlet); + + server.start(); + } + + public void start() throws Exception { + Preconditions.checkArgument(sdModel != null, "SameDiff model wasn't defined"); + + servingServlet = SameDiffServlet.builder() + .sdModel(sdModel) + .serializer(serializer) + .deserializer(deserializer) + .inferenceAdapter(inferenceAdapter) + .orderedInputNodes(orderedInputNodes) + .orderedOutputNodes(orderedOutputNodes) + .build(); + + start(port, servingServlet); + } + + public void join() throws InterruptedException { + Preconditions.checkArgument(server != null, "Model server wasn't started yet"); + + server.join(); + } + + public void stop() throws Exception { + //Preconditions.checkArgument(server != null, "Model server wasn't started yet"); + + server.stop(); + } + + + public static class Builder { + private SameDiff sameDiff; + private String[] orderedInputNodes; + private String[] orderedOutputNodes; + private InferenceAdapter inferenceAdapter; + private JsonSerializer serializer; + private JsonDeserializer deserializer; + private int port; + + private InputAdapter inputAdapter; + private OutputAdapter outputAdapter; + + public Builder() {} + + public Builder sdModel(@NonNull SameDiff sameDiff) { + this.sameDiff = sameDiff; + return this; + } + + /** + * This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type + * @param inferenceAdapter + * @return + */ + public Builder inferenceAdapter(InferenceAdapter inferenceAdapter) { + this.inferenceAdapter = inferenceAdapter; + return this; + } + + /** + * This method allows you to specify InputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require OutputAdapter defined + * @param inputAdapter + * @return + */ + public Builder inputAdapter(@NonNull InputAdapter inputAdapter) { + this.inputAdapter = inputAdapter; + return this; + } + + /** + * This method allows you to specify OutputAdapter to be used for inference + * + * PLEASE NOTE: This method is optional, and will require InputAdapter defined + * @param outputAdapter + * @return + */ + public Builder outputAdapter(@NonNull OutputAdapter outputAdapter) { + this.outputAdapter = outputAdapter; + return this; + } + + /** + * This method defines JsonSerializer instance to be used to convert object of output type into JSON format, so it could be sent over the wire + * + * @param serializer + * @return + */ + public Builder outputSerializer(@NonNull JsonSerializer serializer) { + this.serializer = serializer; + return this; + } + + /** + * This method defines JsonDeserializer instance to be used to convert JSON passed through HTTP into actual object of input type, that will be fed into SameDiff model + * + * @param deserializer + * @return + */ + public Builder inputDeserializer(@NonNull JsonDeserializer deserializer) { + this.deserializer = deserializer; + return this; + } + + /** + * This method defines the order of placeholders to be filled with INDArrays provided by Deserializer + * + * @param args + * @return + */ + public Builder orderedInputNodes(String... args) { + orderedInputNodes = args; + return this; + } + + /** + * This method defines the order of placeholders to be filled with INDArrays provided by Deserializer + * + * @param args + * @return + */ + public Builder orderedInputNodes(@NonNull List args) { + orderedInputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input + * @param args + * @return + */ + public Builder orderedOutputNodes(String... args) { + Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args; + return this; + } + + /** + * This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input + * @param args + * @return + */ + public Builder orderedOutputNodes(@NonNull List args) { + Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element"); + orderedOutputNodes = args.toArray(new String[args.size()]); + return this; + } + + /** + * This method allows to configure HTTP port used for serving + * + * PLEASE NOTE: port must be free and be in range regular TCP/IP ports range + * @param port + * @return + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * This method builds SameDiffJsonModelServer instance + * @return + */ + public SameDiffJsonModelServer build() { + if (inferenceAdapter == null) { + if (inputAdapter != null && outputAdapter != null) { + inferenceAdapter = new InferenceAdapter() { + @Override + public MultiDataSet apply(I input) { + return inputAdapter.apply(input); + } + + @Override + public O apply(INDArray... outputs) { + return outputAdapter.apply(outputs); + } + }; + } else + throw new IllegalArgumentException("Either InferenceAdapter or InputAdapter + OutputAdapter should be configured"); + } + return new SameDiffJsonModelServer(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java similarity index 69% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java rename to nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java index 08ec374b2..dc240b4f6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java @@ -14,21 +14,17 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.context; +package org.nd4j.remote.serving; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; +import javax.servlet.Servlet; /** - * This is simple class-independant storage for device contexts. + * This interface describes Servlet interface extension, suited for ND4J/DL4J model serving + * @param + * @param * - * TODO: Something better then typecast required here * @author raver119@gmail.com */ -@Data -@NoArgsConstructor -@AllArgsConstructor -public class ExternalContext { - private Object context; +public interface ModelServingServlet extends Servlet { + // } diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java new file mode 100644 index 000000000..0be7b8757 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java @@ -0,0 +1,206 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.serving; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; + +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.HttpMethod; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.LinkedHashMap; + +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; + +/** + * This servlet provides SameDiff model serving capabilities + * + * @param + * @param + * + * @author raver119@gmail.com + */ +@NoArgsConstructor +@AllArgsConstructor +@Slf4j +@Builder +public class SameDiffServlet implements ModelServingServlet { + + protected SameDiff sdModel; + protected JsonSerializer serializer; + protected JsonDeserializer deserializer; + protected InferenceAdapter inferenceAdapter; + + protected String[] orderedInputNodes; + protected String[] orderedOutputNodes; + + protected final static String SERVING_ENDPOINT = "/v1/serving"; + protected final static String LISTING_ENDPOINT = "/v1"; + protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable + + protected SameDiffServlet(@NonNull InferenceAdapter inferenceAdapter, @NonNull JsonSerializer serializer, @NonNull JsonDeserializer deserializer){ + this.serializer = serializer; + this.deserializer = deserializer; + this.inferenceAdapter = inferenceAdapter; + } + + @Override + public void init(ServletConfig servletConfig) throws ServletException { + // + } + + @Override + public ServletConfig getServletConfig() { + return null; + } + + @Override + public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException { + // we'll parse request here, and do model serving + val httpRequest = (HttpServletRequest) servletRequest; + val httpResponse = (HttpServletResponse) servletResponse; + + if (httpRequest.getMethod().equals(HttpMethod.GET)) { + doGet(httpRequest, httpResponse); + } + else if (httpRequest.getMethod().equals(HttpMethod.POST)) { + doPost(httpRequest, httpResponse); + } + + } + + protected void sendError(String uri, HttpServletResponse response) throws IOException { + val msg = "Requested endpoint [" + uri + "] not found"; + response.setStatus(404, msg); + response.sendError(404, msg); + } + + protected void sendBadContentType(String actualContentType, HttpServletResponse response) throws IOException { + val msg = "Content type [" + actualContentType + "] not supported"; + response.setStatus(415, msg); + response.sendError(415, msg); + } + + protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response) + throws IOException{ + val contentType = request.getContentType(); + if (!StringUtils.equals(contentType, APPLICATION_JSON)) { + sendBadContentType(contentType, response); + int contentLength = request.getContentLength(); + if (contentLength > PAYLOAD_SIZE_LIMIT) { + response.sendError(500, "Payload size limit violated!"); + } + return false; + } + return true; + } + + protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + val processor = new ServingProcessor(); + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(LISTING_ENDPOINT)) { + val contentType = request.getContentType(); + if (!StringUtils.equals(contentType, APPLICATION_JSON)) { + sendBadContentType(contentType, response); + } + processorReturned = processor.listEndpoints(); + } + else { + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { + val processor = new ServingProcessor(); + String processorReturned = ""; + String path = request.getPathInfo(); + if (path.equals(SERVING_ENDPOINT)) { + val contentType = request.getContentType(); + /*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON), + "Content type is " + contentType);*/ + if (validateRequest(request,response)) { + val stream = request.getInputStream(); + val bufferedReader = new BufferedReader(new InputStreamReader(stream)); + char[] charBuffer = new char[128]; + int bytesRead = -1; + val buffer = new StringBuilder(); + while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { + buffer.append(charBuffer, 0, bytesRead); + } + val requestString = buffer.toString(); + + val mds = inferenceAdapter.apply(deserializer.deserialize(requestString)); + val map = new LinkedHashMap(); + + // optionally define placeholders with names provided in server constructor + if (orderedInputNodes != null && orderedInputNodes.length > 0) { + int cnt = 0; + for (val n : orderedInputNodes) + map.put(n, mds.getFeatures(cnt++)); + } + + val output = sdModel.exec(map, orderedOutputNodes); + val arrays = new INDArray[output.size()]; + + // now we need to get ordered output arrays, as specified in server constructor + int cnt = 0; + for (val n : orderedOutputNodes) + arrays[cnt++] = output.get(n); + + // process result + val result = inferenceAdapter.apply(arrays); + processorReturned = serializer.serialize(result); + } + } else { + // we return error otherwise + sendError(request.getRequestURI(), response); + } + try { + val out = response.getWriter(); + out.write(processorReturned); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + @Override + public String getServletInfo() { + return null; + } + + @Override + public void destroy() { + // + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java new file mode 100644 index 000000000..6f77215bd --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java @@ -0,0 +1,14 @@ +package org.nd4j.remote.serving; + +public class ServingProcessor { + + public String listEndpoints() { + String retVal = "/v1/ \n/v1/serving/"; + return retVal; + } + + public String processModel(String body) { + String response = null; //"Not implemented"; + return response; + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java new file mode 100644 index 000000000..290d1c4a1 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java @@ -0,0 +1,271 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.remote.clients.JsonRemoteInference; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.helpers.House; +import org.nd4j.remote.helpers.HouseToPredictedPriceAdapter; +import org.nd4j.remote.helpers.PredictedPrice; +import org.nd4j.remote.clients.serde.impl.FloatArraySerde; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import static org.junit.Assert.*; + +@Slf4j +public class SameDiffJsonModelServerTest { + + @Test + public void basicServingTest_1() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + val timeStart = System.currentTimeMillis(); + price = client.predict(house); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + assertNotNull(price); + assertEquals((float) district + 1.0f, price.getPrice(), 1e-5); + + server.stop(); + } + + @Test + public void testDeserialization_1() { + String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}"; + val deserializer = new House.HouseDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(2, result.getDistrict()); + assertEquals(100, result.getArea()); + assertEquals(2, result.getBathrooms()); + assertEquals(3, result.getBedrooms()); + + } + + @Test + public void testDeserialization_2() { + String request = "{\"price\":1}"; + val deserializer = new PredictedPrice.PredictedPriceDeserializer(); + val result = deserializer.deserialize(request); + assertEquals(1.0, result.getPrice(), 1e-4); + } + + @Test + public void testDeserialization_3() { + float[] data = {0.0f, 0.1f, 0.2f}; + val serialized = new FloatArraySerde().serialize(data); + val deserialized = new FloatArraySerde().deserialize(serialized); + assertArrayEquals(data, deserialized, 1e-5f); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_1() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(null) + .sdModel(sd) + .port(18080) + .build(); + } + + @Test(expected = NullPointerException.class) + public void negativeServingTest_2() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .sdModel(sd) + .port(18080) + .build(); + + } + + @Test(expected = IOException.class) + public void negativeServingTest_3() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + // warmup + PredictedPrice price = client.predict(house); + + server.stop(); + } + + @Test + public void asyncServingTest() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer()) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + + + server.stop(); + } + + @Test + public void negativeAsyncTest() throws Exception { + val sd = SameDiff.create(); + val sdVariable = sd.placeHolder("input", DataType.INT, 4); + val result = sdVariable.add(1.0); + val total = result.mean("total", Integer.MAX_VALUE); + + val server = new SameDiffJsonModelServer.Builder() + .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) + .inputDeserializer(new House.HouseDeserializer()) + .inferenceAdapter(new HouseToPredictedPriceAdapter()) + .orderedInputNodes(new String[]{"input"}) + .orderedOutputNodes(new String[]{"total"}) + .sdModel(sd) + .port(18080) + .build(); + + server.start(); + + // Fake deserializer to test failure + val client = JsonRemoteInference.builder() + .inputSerializer(new House.HouseSerializer()) + .outputDeserializer(new JsonDeserializer() { + @Override + public PredictedPrice deserialize(String json) { + return null; + } + }) + .endpointAddress("http://localhost:18080/v1/serving") + .build(); + + int district = 2; + House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build(); + + val timeStart = System.currentTimeMillis(); + try { + Future price = client.predictAsync(house); + assertNotNull(price); + assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5); + val timeStop = System.currentTimeMillis(); + + log.info("Time spent: {} ms", timeStop - timeStart); + } catch (ExecutionException e) { + assertTrue(e.getMessage().contains("Deserialization failed")); + } + + server.stop(); + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java new file mode 100644 index 000000000..a26809efc --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java @@ -0,0 +1,116 @@ +package org.nd4j.remote; + +import lombok.val; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.HttpClientBuilder; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; +import org.nd4j.adapters.InferenceAdapter; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class SameDiffServletTest { + + private SameDiffJsonModelServer server; + + @Before + public void setUp() throws Exception { + server = new SameDiffJsonModelServer.Builder() + .sdModel(SameDiff.create()) + .port(8080) + .inferenceAdapter(new InferenceAdapter() { + @Override + public MultiDataSet apply(String input) { + return null; + } + + @Override + public String apply(INDArray... nnOutput) { + return null; + } + }) + .outputSerializer(new JsonSerializer() { + @Override + public String serialize(String o) { + return ""; + } + }) + .inputDeserializer(new JsonDeserializer() { + @Override + public String deserialize(String json) { + return ""; + } + }) + .orderedOutputNodes(new String[]{"output"}) + .build(); + + server.start(); + //server.join(); + } + + @After + public void tearDown() throws Exception { + server.stop(); + } + + @Test + public void getEndpoints() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "application/json"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypeGet() throws IOException { + val request = new HttpGet( "http://localhost:8080/v1" ); + request.setHeader("Content-type", "text/plain"); + + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void testContentTypePost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "text/plain"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(415, response.getStatusLine().getStatusCode()); + } + + @Test + public void postForServing() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(500, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundPost() throws Exception { + val request = new HttpPost("http://localhost:8080/v1/serving/some"); + request.setHeader("Content-type", "application/json"); + val response = HttpClientBuilder.create().build().execute( request ); + assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testNotFoundGet() throws Exception { + val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" ); + requestGet.setHeader("Content-type", "application/json"); + + val responseGet = HttpClientBuilder.create().build().execute( requestGet ); + assertEquals(404, responseGet.getStatusLine().getStatusCode()); + } + +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java new file mode 100644 index 000000000..e5089e585 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.*; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class House { + private int district; + private int bedrooms; + private int bathrooms; + private int area; + + + public static class HouseSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull House o) { + return new Gson().toJson(o); + } + } + + public static class HouseDeserializer implements JsonDeserializer { + @Override + public House deserialize(@NonNull String json) { + return new Gson().fromJson(json, House.class); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java similarity index 56% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java rename to nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java index 263316d02..fe07623d3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java @@ -14,29 +14,27 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.garbage; +package org.nd4j.remote.helpers; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.adapters.InferenceAdapter; -/** - * This class provides Deallocator implementation for tracking/releasing CudaContexts once thread holding it dies - * @author raver119@gmail.com - */ @Slf4j -public class ContextDeallocator implements Deallocator { - private CudaContext context; +public class HouseToPredictedPriceAdapter implements InferenceAdapter { - public ContextDeallocator(@NonNull CudaContext context) { - this.context = context; + @Override + public MultiDataSet apply(@NonNull House input) { + // we just create vector array with shape[4] and assign it's value to the district value + return new MultiDataSet(Nd4j.create(DataType.FLOAT, 4).assign(input.getDistrict()), null); } @Override - public void deallocate() { - AtomicAllocator.getInstance().getContextPool().releaseContext(context); + public PredictedPrice apply(INDArray... nnOutput) { + return new PredictedPrice(nnOutput[0].getFloat(0)); } } diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java new file mode 100644 index 000000000..41d2bd253 --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.remote.helpers; + +import com.google.gson.Gson; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.remote.clients.serde.JsonDeserializer; +import org.nd4j.remote.clients.serde.JsonSerializer; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class PredictedPrice { + private float price; + + public static class PredictedPriceSerializer implements JsonSerializer { + @Override + public String serialize(@NonNull PredictedPrice o) { + return new Gson().toJson(o); + } + } + + public static class PredictedPriceDeserializer implements JsonDeserializer { + @Override + public PredictedPrice deserialize(@NonNull String json) { + return new Gson().fromJson(json, PredictedPrice.class); + } + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java new file mode 100644 index 000000000..3aca94b8a --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java @@ -0,0 +1,90 @@ +package org.nd4j.remote.serde; + +import lombok.val; +import org.junit.Test; +import org.nd4j.remote.clients.serde.impl.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class BasicSerdeTests { + private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde(); + private final static FloatArraySerde floatArraySerde = new FloatArraySerde(); + private final static StringSerde stringSerde = new StringSerde(); + private final static IntegerSerde integerSerde = new IntegerSerde(); + private final static FloatSerde floatSerde = new FloatSerde(); + private final static DoubleSerde doubleSerde = new DoubleSerde(); + private final static BooleanSerde booleanSerde = new BooleanSerde(); + + @Test + public void testStringSerde_1() { + val jvmString = "String with { strange } elements"; + + val serialized = stringSerde.serialize(jvmString); + val deserialized = stringSerde.deserialize(serialized); + + assertEquals(jvmString, deserialized); + } + + @Test + public void testFloatArraySerDe_1() { + val jvmArray = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + val serialized = floatArraySerde.serialize(jvmArray); + val deserialized = floatArraySerde.deserialize(serialized); + + assertArrayEquals(jvmArray, deserialized, 1e-5f); + } + + @Test + public void testDoubleArraySerDe_1() { + val jvmArray = new double[] {1.0, 2.0, 3.0, 4.0, 5.0}; + + val serialized = doubleArraySerde.serialize(jvmArray); + val deserialized = doubleArraySerde.deserialize(serialized); + + assertArrayEquals(jvmArray, deserialized, 1e-5); + } + + @Test + public void testFloatSerde_1() { + val f = 119.f; + + val serialized = floatSerde.serialize(f); + val deserialized = floatSerde.deserialize(serialized); + + assertEquals(f, deserialized, 1e-5f); + } + + @Test + public void testDoubleSerde_1() { + val d = 119.; + + val serialized = doubleSerde.serialize(d); + val deserialized = doubleSerde.deserialize(serialized); + + assertEquals(d, deserialized, 1e-5); + } + + @Test + public void testIntegerSerde_1() { + val f = 119; + + val serialized = integerSerde.serialize(f); + val deserialized = integerSerde.deserialize(serialized); + + + assertEquals(f, deserialized.intValue()); + } + + @Test + public void testBooleanSerde_1() { + val f = true; + + val serialized = booleanSerde.serialize(f); + val deserialized = booleanSerde.deserialize(serialized); + + + assertEquals(f, deserialized); + } +} diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml new file mode 100644 index 000000000..59b35644e --- /dev/null +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml @@ -0,0 +1,48 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/nd4j/nd4j-remote/pom.xml b/nd4j/nd4j-remote/pom.xml new file mode 100644 index 000000000..c6a0cdb86 --- /dev/null +++ b/nd4j/nd4j-remote/pom.xml @@ -0,0 +1,35 @@ + + + + 4.0.0 + pom + + + nd4j-json-client + nd4j-grpc-client + nd4j-json-server + + + + org.nd4j + nd4j + 1.0.0-SNAPSHOT + + + nd4j-remote + 1.0.0-SNAPSHOT + nd4j-remote + + + UTF-8 + 1.7 + 1.7 + + + + + testresources + + + diff --git a/nd4j/nd4j-serde/pom.xml b/nd4j/nd4j-serde/pom.xml index 3234912cb..d4fc4ff05 100644 --- a/nd4j/nd4j-serde/pom.xml +++ b/nd4j/nd4j-serde/pom.xml @@ -33,7 +33,6 @@ nd4j-camel-routes nd4j-gson nd4j-arrow - nd4j-grpc diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 255859171..6c294d7e7 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -62,6 +62,7 @@ nd4j-parameter-server-parent nd4j-uberjar nd4j-tensorflow + nd4j-remote