Merge remote-tracking branch 'fork/master'

master
AlexDBlack 2019-08-15 12:02:42 +10:00
commit 37e053ad90
256 changed files with 6336 additions and 2549 deletions

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;

View File

@ -51,11 +51,13 @@
<artifactId>datavec-spark-inference-model</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark_2.11</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
@ -67,61 +69,73 @@
<artifactId>akka-cluster_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>${jodatime.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>
<dependency>
<groupId>org.hibernate</groupId>
<artifactId>hibernate-validator</artifactId>
<version>${hibernate.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
<version>${snakeyaml.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
@ -137,39 +151,44 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
<version>${jodah.typetools.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${play.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${play.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId>
<version>${play.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${play.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>

View File

@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest {
public static void before() throws Exception {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest {
}
}
});
server.runMain(new String[] {"-dp", "9050"});
}

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest {
}
}
});
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
}

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;

View File

@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
import com.mashape.unirest.request.HttpRequestWithBody;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
@ -51,6 +51,7 @@ public class NearestNeighborsClient {
static {
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@ -89,7 +90,7 @@ public class NearestNeighborsClient {
NearestNeighborRequest request = new NearestNeighborRequest();
request.setInputIndex(index);
request.setK(k);
HttpRequestWithBody req = Unirest.post(url + "/knn");
val req = Unirest.post(url + "/knn");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(request);
addAuthHeader(req);
@ -112,7 +113,7 @@ public class NearestNeighborsClient {
Base64NDArrayBody base64NDArrayBody =
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
HttpRequestWithBody req = Unirest.post(url + "/knnnew");
val req = Unirest.post(url + "/knnnew");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(base64NDArrayBody);
addAuthHeader(req);

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import jdk.nashorn.internal.objects.annotations.Property;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.compress.compressors.gzip.GzipUtils;

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.nn.adapters;
import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.api;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;
/**

View File

@ -24,6 +24,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.exception.DL4JException;

View File

@ -25,6 +25,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;

View File

@ -0,0 +1,110 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module descriptor for the deeplearning4j-json-server artifact:
     JSON-based model serving for DL4J / SameDiff models. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <packaging>jar</packaging>
    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-remote</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>
    <artifactId>deeplearning4j-json-server</artifactId>
    <version>1.0.0-SNAPSHOT</version>
    <name>deeplearning4j-json-server</name>
    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>${junit.version}</version>
            <scope>test</scope>
        </dependency>
        <!-- Lombok is compile-time only (annotation processing). -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!-- nd4j-json-client/server provide the generic JSON serving infrastructure
             (SameDiffServlet, JsonRemoteInference) this module builds on. -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-json-client</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-json-server</artifactId>
            <version>${project.version}</version>
        </dependency>
        <!-- Provides ParallelInference used for multi-worker serving. -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <!-- Backend selection for tests: CPU (default) or CUDA 10.1. -->
    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>test-nd4j-cuda-10.1</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-10.1</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>

View File

@ -0,0 +1,205 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.remote.serving.SameDiffServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
/**
 * Servlet that serves a DL4J model over JSON: request bodies are deserialized to the
 * input type {@code I}, passed through the configured {@link InferenceAdapter}, and the
 * typed result {@code O} is serialized back to the client.
 *
 * Inference is performed either via {@link ParallelInference} (thread-safe, preferred)
 * or by calling the wrapped {@link Model} directly under a lock.
 *
 * @param <I> type of Input class
 * @param <O> type of Output class
 *
 * @author astoyakin
 */
@Slf4j
@NoArgsConstructor
public class DL4jServlet<I, O> extends SameDiffServlet<I, O> {

    // Set when serving through ParallelInference; null in direct-model mode.
    protected ParallelInference parallelInference;
    // Set when serving a ComputationGraph/MultiLayerNetwork directly; null in parallel mode.
    protected Model model;
    // Selects between the two serving paths above.
    protected boolean parallelEnabled = true;

    /**
     * Creates a servlet that serves through {@link ParallelInference}.
     */
    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
                       @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
        super(inferenceAdapter, serializer, deserializer);
        this.parallelInference = parallelInference;
        this.model = null;
        this.parallelEnabled = true;
    }

    /**
     * Creates a servlet that invokes the given {@link Model} directly (single-threaded,
     * access is synchronized in {@link #doPost}).
     */
    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
                       @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
        super(inferenceAdapter, serializer, deserializer);
        this.model = model;
        this.parallelInference = null;
        this.parallelEnabled = false;
    }

    /**
     * Handles POST requests to the serving endpoint: drains the request body, deserializes
     * it into {@code I}, runs inference and writes the serialized {@code O} to the response.
     * Any other path gets an error response.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
        String processorReturned = "";
        String path = request.getPathInfo();
        if (path.equals(SERVING_ENDPOINT)) {
            if (validateRequest(request, response)) {
                val stream = request.getInputStream();
                val bufferedReader = new BufferedReader(new InputStreamReader(stream));
                char[] charBuffer = new char[128];
                int charsRead;
                val buffer = new StringBuilder();
                // Drain the whole request body into the buffer.
                while ((charsRead = bufferedReader.read(charBuffer)) > 0) {
                    buffer.append(charBuffer, 0, charsRead);
                }
                val requestString = buffer.toString();

                val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
                O result = null;
                if (parallelEnabled) {
                    // ParallelInference handles its own concurrency; no locking needed here.
                    result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
                } else {
                    // Direct model invocation is not thread-safe; serialize concurrent requests.
                    synchronized (this) {
                        if (model instanceof ComputationGraph)
                            result = inferenceAdapter.apply(((ComputationGraph) model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
                        else if (model instanceof MultiLayerNetwork) {
                            // NOTE(review): this condition passes only for multi-input data, yet the
                            // branch below consumes a single input — the check looks inverted; confirm intent.
                            Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
                                    "Input data for MultilayerNetwork is invalid!");
                            result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
                                    mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
                        }
                    }
                }
                processorReturned = serializer.serialize(result);
            }
        } else {
            // Unknown path - report an error to the client.
            sendError(request.getRequestURI(), response);
        }

        try {
            val out = response.getWriter();
            out.write(processorReturned);
        } catch (IOException e) {
            log.error(e.getMessage());
        }
    }

    /**
     * Creates servlet to serve models
     *
     * @param <I> type of Input class
     * @param <O> type of Output class
     *
     * @author raver119@gmail.com
     * @author astoyakin
     */
    public static class Builder<I, O> {
        private ParallelInference pi;
        private Model model;

        private InferenceAdapter<I, O> inferenceAdapter;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;
        // NOTE(review): port is recorded but never consumed by build(); the listening
        // port is configured on the enclosing server, not the servlet — confirm.
        private int port;
        private boolean parallelEnabled = true;

        public Builder(@NonNull ParallelInference pi) {
            this.pi = pi;
        }

        public Builder(@NonNull Model model) {
            this.model = model;
        }

        /**
         * This method is required to specify the adapter converting I to INDArrays and back to O
         *
         * @param inferenceAdapter
         * @return
         */
        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        /**
         * This method is required to specify serializer
         *
         * @param serializer
         * @return
         */
        public Builder<I, O> serializer(@NonNull JsonSerializer<O> serializer) {
            this.serializer = serializer;
            return this;
        }

        /**
         * This method allows to specify deserializer
         *
         * @param deserializer
         * @return
         */
        public Builder<I, O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
            this.deserializer = deserializer;
            return this;
        }

        /**
         * This method allows to specify port
         *
         * @param port
         * @return
         */
        public Builder<I, O> port(int port) {
            this.port = port;
            return this;
        }

        /**
         * This method activates parallel inference
         *
         * @param parallelEnabled
         * @return
         */
        public Builder<I, O> parallelEnabled(boolean parallelEnabled) {
            this.parallelEnabled = parallelEnabled;
            return this;
        }

        public DL4jServlet<I, O> build() {
            return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer)
                            : new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
        }
    }
}

View File

@ -0,0 +1,392 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.adapters.InputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.remote.SameDiffJsonModelServer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.util.List;
/**
 * This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models
 *
 * Server url will be http://0.0.0.0:{port}/v1/serving
 * Server only accepts POST requests
 *
 * @param <I> type of the input class, i.e. String
 * @param <O> type of the output class, i.e. Sentiment
 *
 * @author raver119@gmail.com
 * @author astoyakin
 */
public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {

    // all serving goes through ParallelInference
    protected ParallelInference parallelInference;

    protected ModelAdapter<O> modelAdapter;

    // actual models
    protected ComputationGraph cgModel;
    protected MultiLayerNetwork mlnModel;

    // service stuff
    protected InferenceMode inferenceMode;
    protected int numWorkers;

    // when false, start() calls the model directly instead of building a ParallelInference
    protected boolean enabledParallel = true;

    protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
        super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
    }

    protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
        super(inferenceAdapter, serializer, deserializer, port);
        this.cgModel = cgModel;
        this.inferenceMode = inferenceMode;
        this.numWorkers = numWorkers;
    }

    protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
        super(inferenceAdapter, serializer, deserializer, port);
        this.mlnModel = mlnModel;
        this.inferenceMode = inferenceMode;
        this.numWorkers = numWorkers;
    }

    protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
        super(inferenceAdapter, serializer, deserializer, port);
        this.parallelInference = pi;
    }

    /**
     * This method stops server
     *
     * @throws Exception
     */
    @Override
    public void stop() throws Exception {
        if (parallelInference != null)
            parallelInference.shutdown();
        super.stop();
    }

    /**
     * This method starts server
     * @throws Exception
     */
    @Override
    public void start() throws Exception {
        // if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case
        if (sdModel != null) {
            super.start();
            return;
        }
        Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined");

        val model = cgModel != null ? (Model) cgModel : (Model) mlnModel;
        // PI construction is optional, since we can have it defined
        if (enabledParallel) {
            if (parallelInference == null) {
                Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");

                parallelInference = new ParallelInference.Builder(model)
                        .inferenceMode(inferenceMode)
                        .workers(numWorkers)
                        .loadBalanceMode(LoadBalanceMode.FIFO)
                        .batchLimit(16)
                        .queueLimit(128)
                        .build();
            }
            servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
                    .parallelEnabled(true)
                    .serializer(serializer)
                    .deserializer(deserializer)
                    .inferenceAdapter(inferenceAdapter)
                    .build();
        } else {
            servingServlet = new DL4jServlet.Builder<I, O>(model)
                    .parallelEnabled(false)
                    .serializer(serializer)
                    .deserializer(deserializer)
                    .inferenceAdapter(inferenceAdapter)
                    .build();
        }
        start(port, servingServlet);
    }

    /**
     * Creates servlet to serve different types of models
     *
     * @param <I> type of Input class
     * @param <O> type of Output class
     *
     * @author raver119@gmail.com
     * @author astoyakin
     */
    public static class Builder<I, O> {
        private SameDiff sdModel;
        private ComputationGraph cgModel;
        private MultiLayerNetwork mlnModel;
        private ParallelInference pi;

        private String[] orderedInputNodes;
        private String[] orderedOutputNodes;

        private InferenceAdapter<I, O> inferenceAdapter;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;

        private InputAdapter<I> inputAdapter;
        private OutputAdapter<O> outputAdapter;

        private int port;

        private boolean parallelMode = true;

        // these fields actually require defaults
        private InferenceMode inferenceMode = InferenceMode.BATCHED;
        private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();

        public Builder(@NonNull SameDiff sdModel) {
            this.sdModel = sdModel;
        }

        public Builder(@NonNull MultiLayerNetwork mlnModel) {
            this.mlnModel = mlnModel;
        }

        public Builder(@NonNull ComputationGraph cgModel) {
            this.cgModel = cgModel;
        }

        public Builder(@NonNull ParallelInference pi) {
            this.pi = pi;
        }

        /**
         * This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type
         * @param inferenceAdapter
         * @return
         */
        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        /**
         * This method allows you to specify InputAdapter to be used for inference
         *
         * PLEASE NOTE: This method is optional, and will require OutputAdapter<O> defined
         * @param inputAdapter
         * @return
         */
        public Builder<I, O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
            this.inputAdapter = inputAdapter;
            return this;
        }

        /**
         * This method allows you to specify OutputAdapter to be used for inference
         *
         * PLEASE NOTE: This method is optional, and will require InputAdapter<I> defined
         * @param outputAdapter
         * @return
         */
        public Builder<I, O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
            this.outputAdapter = outputAdapter;
            return this;
        }

        /**
         * This method allows you to specify serializer
         *
         * @param serializer
         * @return
         */
        public Builder<I, O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
            this.serializer = serializer;
            return this;
        }

        /**
         * This method allows you to specify deserializer
         *
         * @param deserializer
         * @return
         */
        public Builder<I, O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
            this.deserializer = deserializer;
            return this;
        }

        /**
         * This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
         *
         * @param inferenceMode
         * @return
         */
        public Builder<I, O> inferenceMode(@NonNull InferenceMode inferenceMode) {
            this.inferenceMode = inferenceMode;
            return this;
        }

        /**
         * This method allows you to specify number of worker threads for ParallelInference
         *
         * @param numWorkers
         * @return
         */
        public Builder<I, O> numWorkers(int numWorkers) {
            this.numWorkers = numWorkers;
            return this;
        }

        /**
         * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
         *
         * PLEASE NOTE: this argument only used for SameDiff models
         * @param args
         * @return
         */
        public Builder<I, O> orderedInputNodes(String... args) {
            orderedInputNodes = args;
            return this;
        }

        /**
         * This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
         *
         * PLEASE NOTE: this argument only used for SameDiff models
         * @param args
         * @return
         */
        public Builder<I, O> orderedInputNodes(@NonNull List<String> args) {
            orderedInputNodes = args.toArray(new String[args.size()]);
            return this;
        }

        /**
         * This method allows you to specify output nodes
         *
         * PLEASE NOTE: this argument only used for SameDiff models
         * @param args
         * @return
         */
        public Builder<I, O> orderedOutputNodes(String... args) {
            Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
            orderedOutputNodes = args;
            return this;
        }

        /**
         * This method allows you to specify output nodes
         *
         * PLEASE NOTE: this argument only used for SameDiff models
         * @param args
         * @return
         */
        public Builder<I, O> orderedOutputNodes(@NonNull List<String> args) {
            Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
            orderedOutputNodes = args.toArray(new String[args.size()]);
            return this;
        }

        /**
         * This method allows you to specify http port
         *
         * PLEASE NOTE: port must be free and be in range regular TCP/IP ports range
         * @param port
         * @return
         */
        public Builder<I, O> port(int port) {
            this.port = port;
            return this;
        }

        /**
         * This method switches on ParallelInference usage
         *
         * PLEASE NOTE: this doesn't apply to SameDiff models
         *
         * @param enable true - to use ParallelInference, false - to use ComputationGraph or
         *               MultiLayerNetwork directly
         * @return
         */
        public Builder<I, O> parallelMode(boolean enable) {
            this.parallelMode = enable;
            return this;
        }

        public JsonModelServer<I, O> build() {
            if (inferenceAdapter == null) {
                if (inputAdapter != null && outputAdapter != null) {
                    // synthesize an InferenceAdapter from the input/output adapter pair
                    inferenceAdapter = new InferenceAdapter<I, O>() {
                        @Override
                        public MultiDataSet apply(I input) {
                            return inputAdapter.apply(input);
                        }

                        @Override
                        public O apply(INDArray... outputs) {
                            return outputAdapter.apply(outputs);
                        }
                    };
                } else
                    throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
            }

            if (sdModel != null) {
                Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
                return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
            } else if (cgModel != null) {
                val server = new JsonModelServer<I, O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
                // FIX: propagate parallelMode to the server - previously parallelMode(false)
                // was silently ignored and start() always used ParallelInference
                server.enabledParallel = parallelMode;
                return server;
            } else if (mlnModel != null) {
                val server = new JsonModelServer<I, O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
                // FIX: see above - propagate parallelMode
                server.enabledParallel = parallelMode;
                return server;
            } else if (pi != null)
                return new JsonModelServer<I, O>(pi, inferenceAdapter, serializer, deserializer, port);
            else
                throw new IllegalStateException("No models were defined for JsonModelServer");
        }
    }
}

View File

@ -0,0 +1,749 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.remote.helpers.House;
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
import org.deeplearning4j.remote.helpers.PredictedPrice;
import org.junit.After;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.remote.clients.JsonRemoteInference;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
import static org.junit.Assert.*;
// Integration tests for JsonModelServer: serving MultiLayerNetwork, ComputationGraph and
// SameDiff models over HTTP/JSON, plus client-side (de)serialization edge cases.
@Slf4j
public class JsonModelServerTest {
// Shared DL4J model under test: 4 -> 10 (tanh) -> 1 (sigmoid) regression net.
private static final MultiLayerNetwork model;
// Base port for servers started by these tests; several tests also bind PORT + 1.
private final int PORT = 18080;
static {
// Seeded configuration so weight init is reproducible: several tests below assert
// the exact served value 0.421444f for a district-2 input vector.
val conf = new NeuralNetConfiguration.Builder()
.seed(119)
.updater(new Adam(0.119f))
.weightInit(WeightInit.XAVIER)
.list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build())
.build();
model = new MultiLayerNetwork(conf);
model.init();
}
/**
 * Runs after every test: gives the OS time to release the just-closed server port
 * before the next test binds it again.
 */
@After
public void pause() throws Exception {
    // TODO: replace the fixed delay with a check that the port is actually free.
    Thread.sleep(2000);
}
/**
 * Starts a DL4J-backed server (PORT) and a SameDiff-backed server (PORT + 1) at the same
 * time, sends requests to both, and verifies each returns its expected deterministic
 * prediction before both are stopped.
 */
@Test
public void testStartStopParallel() throws Exception {
// Minimal SameDiff graph: total = mean(input + 1) over all elements.
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(SEQUENTIAL)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.port(PORT+1)
.build();
try {
serverDL.start();
serverSD.start();
val clientDL = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup request (includes connection setup); the second request is the timed one
PredictedPrice price = clientDL.predict(house);
long timeStart = System.currentTimeMillis();
price = clientDL.predict(house);
long timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
// Deterministic output of the seeded DL4J model for a district-2 input vector.
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
val clientSD = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
.build();
// price2 is only the SameDiff server's warmup call; its value is not asserted.
PredictedPrice price2 = clientSD.predict(house);
timeStart = System.currentTimeMillis();
price = clientSD.predict(house);
timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
// mean(district + 1) over a vector filled with the district id == district + 1 == 3.0
assertEquals((float) 3.0, price.getPrice(), 1e-5);
}
finally {
serverSD.stop();
serverDL.stop();
}
}
/**
 * Verifies that a DL4J-backed server and a SameDiff-backed server can each be started
 * and stopped cleanly, one after the other.
 */
@Test
public void testStartStopSequential() throws Exception {
    // Minimal SameDiff graph: total = mean(input + 1) over all elements.
    val graph = SameDiff.create();
    val input = graph.placeHolder("input", DataType.INT, 1, 4);
    val plusOne = input.add(1.0);
    plusOne.mean("total", Integer.MAX_VALUE);

    val dl4jServer = new JsonModelServer.Builder<House, PredictedPrice>(model)
            .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
            .inputDeserializer(new House.HouseDeserializer())
            .inferenceAdapter(new HouseToPredictedPriceAdapter())
            .inferenceMode(SEQUENTIAL)
            .numWorkers(1)
            .port(PORT)
            .build();

    val sameDiffServer = new JsonModelServer.Builder<House, PredictedPrice>(graph)
            .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
            .inputDeserializer(new House.HouseDeserializer())
            .inferenceAdapter(new HouseToPredictedPriceAdapter())
            .orderedInputNodes(new String[]{"input"})
            .orderedOutputNodes(new String[]{"total"})
            .port(PORT + 1)
            .build();

    // Lifecycle only — no requests are sent in this test.
    dl4jServer.start();
    dl4jServer.stop();
    sameDiffServer.start();
    sameDiffServer.stop();
}
/**
 * Serves a tiny SameDiff graph (total = mean(input + 1)) and checks one JSON round-trip.
 * The adapter fills a 1x4 vector with the district id, so the served result is district + 1.
 */
@Test
public void basicServingTestForSD() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
// timed second request (the first includes connection setup)
val timeStart = System.currentTimeMillis();
price = client.predict(house);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
// mean over a 1x4 vector of (district + 1) values == district + 1
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
}
finally {
server.stop();
}
}
/**
 * Serves the shared DL4J model with a single worker in INPLACE (synchronized) mode and
 * issues several predictions. HouseToPredictedPriceAdapter feeds only the district into
 * the model and all three houses share district == 2, so every prediction must equal the
 * deterministic warmup value 0.421444f. (The original computed price1..price3 but never
 * asserted their values — fixed here.)
 */
@Test
public void basicServingTestForDLSynchronized() throws Exception {
    val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
            .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
            .inputDeserializer(new House.HouseDeserializer())
            .numWorkers(1)
            .inferenceMode(INPLACE)
            .inferenceAdapter(new HouseToPredictedPriceAdapter())
            .port(PORT)
            .build();
    try {
        server.start();
        val client = JsonRemoteInference.<House, PredictedPrice>builder()
                .inputSerializer(new House.HouseSerializer())
                .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
                .endpointAddress("http://localhost:" + PORT + "/v1/serving")
                .build();

        int district = 2;
        House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
        House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build();
        House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build();

        // warmup (includes connection setup); subsequent requests are the timed ones
        PredictedPrice price = client.predict(house1);

        val timeStart = System.currentTimeMillis();
        PredictedPrice price1 = client.predict(house1);
        PredictedPrice price2 = client.predict(house2);
        PredictedPrice price3 = client.predict(house3);
        val timeStop = System.currentTimeMillis();
        log.info("Time spent: {} ms", timeStop - timeStart);

        assertNotNull(price);
        // Deterministic output of the seeded model for a district-2 input vector.
        assertEquals((float) 0.421444, price.getPrice(), 1e-5);
        // Only the district reaches the model, and it is identical for all three houses.
        assertEquals(price.getPrice(), price1.getPrice(), 1e-5);
        assertEquals(price.getPrice(), price2.getPrice(), 1e-5);
        assertEquals(price.getPrice(), price3.getPrice(), 1e-5);
    } finally {
        server.stop();
    }
}
/**
 * Serves the shared DL4J model in SEQUENTIAL mode (parallelMode disabled) and checks a
 * single JSON prediction round-trip against the deterministic expected value.
 */
@Test
public void basicServingTestForDL() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(SEQUENTIAL)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.parallelMode(false)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
// timed second request (the first includes connection setup)
val timeStart = System.currentTimeMillis();
price = client.predict(house);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
// Deterministic output of the seeded model for a district-2 input vector.
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
} finally {
server.stop();
}
}
/**
 * The Gson-backed House deserializer must map every JSON field onto the POJO.
 */
@Test
public void testDeserialization_1() {
    final String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
    House parsed = new House.HouseDeserializer().deserialize(request);
    assertEquals(2, parsed.getDistrict());
    assertEquals(100, parsed.getArea());
    assertEquals(2, parsed.getBathrooms());
    assertEquals(3, parsed.getBedrooms());
}
/**
 * The Gson-backed PredictedPrice deserializer must parse the price field.
 */
@Test
public void testDeserialization_2() {
    final String request = "{\"price\":1}";
    PredictedPrice parsed = new PredictedPrice.PredictedPriceDeserializer().deserialize(request);
    assertEquals(1.0, parsed.getPrice(), 1e-4);
}
/**
 * The builder must reject a null input deserializer with a NullPointerException.
 * Uses the shared PORT constant instead of the original magic 18080 (same value).
 */
@Test(expected = NullPointerException.class)
public void negativeServingTest_1() throws Exception {
    val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
            .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
            .inputDeserializer(null)
            .port(PORT)
            .build();
}
// NOTE(review): the expected-exception clause is commented out, so this currently only
// checks that build() succeeds without numWorkers/inferenceMode being set; confirm whether
// the builder is still supposed to reject this configuration with an NPE.
@Test //(expected = NullPointerException.class)
public void negativeServingTest_2() throws Exception {
// Builder is given serde + adapter but no explicit worker count or inference mode.
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
}
/**
 * A client-side deserializer that returns null must surface as an IOException from
 * predict(). The endpoint now uses the shared PORT constant — the original hard-coded
 * 18080 while the server above binds PORT, an accident waiting to happen if the
 * constant ever changes.
 */
@Test(expected = IOException.class)
public void negativeServingTest_3() throws Exception {
    val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
            .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
            .inputDeserializer(new House.HouseDeserializer())
            .inferenceAdapter(new HouseToPredictedPriceAdapter())
            .inferenceMode(SEQUENTIAL)
            .numWorkers(1)
            .port(PORT)
            .build();
    try {
        server.start();
        val client = JsonRemoteInference.<House, PredictedPrice>builder()
                .inputSerializer(new House.HouseSerializer())
                .outputDeserializer(new JsonDeserializer<PredictedPrice>() {
                    @Override
                    public PredictedPrice deserialize(String json) {
                        // Intentionally broken deserializer.
                        return null;
                    }
                })
                .endpointAddress("http://localhost:" + PORT + "/v1/serving")
                .build();

        int district = 2;
        House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
        // This call is expected to throw.
        PredictedPrice price = client.predict(house);
    } finally {
        server.stop();
    }
}
/**
 * Exercises the asynchronous client path: predictAsync() returns a Future whose value
 * must match the deterministic prediction for a district-2 input (0.421444f).
 */
@Test
public void asyncServingTest() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
// Future.get() blocks until the server responds.
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
assertEquals((float) 0.421444, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
}
finally {
server.stop();
}
}
/**
 * Async path with a deliberately broken client-side deserializer (returns null): the
 * Future must fail and surface an ExecutionException mentioning the deserialization
 * failure.
 */
@Test
public void negativeAsyncTest() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.inferenceMode(InferenceMode.BATCHED)
.numWorkers(1)
.port(PORT)
.build();
try {
server.start();
// Fake deserializer to test failure
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
@Override
public PredictedPrice deserialize(String json) {
return null;
}
})
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
try {
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
// NOTE(review): price.get() is expected to throw, so the value assertion below should be
// unreachable; its expected value (district + 1) also looks copied from the SameDiff tests,
// while this server wraps the DL4J model — confirm intent.
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
} catch (ExecutionException e) {
assertTrue(e.getMessage().contains("Deserialization failed"));
}
} finally {
server.stop();
}
}
/**
 * Serves a SameDiff softmax-regression graph for MNIST-sized input (1 x 784) and, for ten
 * random inputs, checks the remote argmax prediction against local sd.output() on the
 * same graph.
 */
@Test
public void testSameDiffMnist() throws Exception {
// softmax(in * w + b) with randomly initialized (but fixed-at-build) weights.
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
val server = new JsonModelServer.Builder<float[], Integer>(sd)
.outputSerializer( new IntSerde())
.inputDeserializer(new FloatSerde())
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
@Override
public MultiDataSet apply(float[] input) {
// Wrap the flat float array as a single 1 x N feature row; no labels.
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
}
@Override
public Integer apply(INDArray... nnOutput) {
// Class prediction = argmax over the softmax output.
return nnOutput[0].argMax().getInt(0);
}
})
.orderedInputNodes("in")
.orderedOutputNodes("softmax")
.port(PORT+1)
.build();
val client = JsonRemoteInference.<float[], Integer>builder()
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
.outputDeserializer(new IntSerde())
.inputSerializer( new FloatSerde())
.build();
try{
server.start();
for( int i=0; i<10; i++ ){
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28);
// Local reference result on the same graph instance.
INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax");
float[] fArr = f.toFloatVector();
int out = client.predict(fArr);
assertEquals(exp.argMax().getInt(0), out);
}
} finally {
server.stop();
}
}
/**
 * Round-trips ten random MNIST-sized inputs through a served MultiLayerNetwork and checks
 * the remotely computed argmax against local net.output() on the same network.
 */
@Test
public void testMlnMnist() throws Exception {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(10).build())
            .layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    val server = new JsonModelServer.Builder<float[], Integer>(net)
            .outputSerializer(new IntSerde())
            .inputDeserializer(new FloatSerde())
            .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
                @Override
                public MultiDataSet apply(float[] input) {
                    // Wrap the flat float array as a single 1 x N feature row; no labels.
                    return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
                }

                @Override
                public Integer apply(INDArray... nnOutput) {
                    // Class prediction = argmax over the softmax output.
                    return nnOutput[0].argMax().getInt(0);
                }
            })
            // NOTE(review): named nodes are only meaningful for SameDiff models; they appear
            // to be ignored for a MultiLayerNetwork — confirm before removing.
            .orderedInputNodes("in")
            .orderedOutputNodes("softmax")
            .port(PORT + 1)
            .inferenceMode(SEQUENTIAL)
            .numWorkers(2)
            .build();

    val client = JsonRemoteInference.<float[], Integer>builder()
            .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
            .outputDeserializer(new IntSerde())
            .inputSerializer(new FloatSerde())
            .build();

    try {
        server.start();
        for (int i = 0; i < 10; i++) {
            INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28);
            INDArray exp = net.output(f);
            float[] fArr = f.toFloatVector();
            int out = client.predict(fArr);
            assertEquals(exp.argMax().getInt(0), out);
        }
    } finally {
        // Removed the redundant catch (Exception) { printStackTrace; throw e; } — JUnit
        // already reports the full stack trace of a failing test.
        server.stop();
    }
}
/**
 * Verifies that a two-input ComputationGraph can be wrapped by the JSON server and
 * started/stopped cleanly. The live predict() round-trip is disabled below.
 */
@Test
public void testCompGraph() throws Exception {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("input1", "input2")
            .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
            .addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
            .addVertex("merge", new MergeVertex(), "L1", "L2")
            .addLayer("out", new OutputLayer.Builder().nIn(4 + 4).nOut(3).build(), "merge")
            .setOutputs("out")
            .build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    val server = new JsonModelServer.Builder<float[], Integer>(net)
            .outputSerializer(new IntSerde())
            .inputDeserializer(new FloatSerde())
            .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
                @Override
                public MultiDataSet apply(float[] input) {
                    return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
                }

                @Override
                public Integer apply(INDArray... nnOutput) {
                    return nnOutput[0].argMax().getInt(0);
                }
            })
            // NOTE(review): "in"/"softmax" do not exist in this graph (inputs are
            // "input1"/"input2", output is "out") — presumably why predict() is disabled
            // below. Confirm and fix the names before re-enabling the round-trip.
            .orderedInputNodes("in")
            .orderedOutputNodes("softmax")
            .port(PORT + 1)
            .inferenceMode(SEQUENTIAL)
            .numWorkers(2)
            .parallelMode(false)
            .build();

    val client = JsonRemoteInference.<float[], Integer>builder()
            .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
            .outputDeserializer(new IntSerde())
            .inputSerializer(new FloatSerde())
            .build();

    try {
        server.start();
        //client.predict(new float[]{0.0f, 1.0f, 2.0f});
    } finally {
        // Removed the redundant catch (Exception) { printStackTrace; throw e; }.
        server.stop();
    }
}
/**
 * Serves a two-output ComputationGraph and performs one live predict() round-trip; only
 * the first graph output feeds the argmax answer.
 */
@Test
public void testCompGraph_1() throws Exception {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.01))
            .graphBuilder()
            .addInputs("input")
            .addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input")
            .addLayer("out1", new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(4).nOut(3).build(), "L1")
            .addLayer("out2", new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MSE)
                    .nIn(4).nOut(2).build(), "L1")
            .setOutputs("out1", "out2")
            .build();

    final ComputationGraph net = new ComputationGraph(conf);
    net.init();

    val server = new JsonModelServer.Builder<float[], Integer>(net)
            .outputSerializer(new IntSerde())
            .inputDeserializer(new FloatSerde())
            .inferenceAdapter(new InferenceAdapter<float[], Integer>() {
                @Override
                public MultiDataSet apply(float[] input) {
                    return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
                }

                @Override
                public Integer apply(INDArray... nnOutput) {
                    // Only the first graph output is used for the argmax answer.
                    return nnOutput[0].argMax().getInt(0);
                }
            })
            .orderedInputNodes("input")
            // NOTE(review): this graph's outputs are "out1"/"out2", not "out"; the test
            // passes regardless, so the server apparently tolerates the name — confirm.
            .orderedOutputNodes("out")
            .port(PORT + 1)
            .inferenceMode(SEQUENTIAL)
            .numWorkers(2)
            .parallelMode(false)
            .build();

    val client = JsonRemoteInference.<float[], Integer>builder()
            .endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
            .outputDeserializer(new IntSerde())
            .inputSerializer(new FloatSerde())
            .build();

    try {
        server.start();
        val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
        assertNotNull(result);
    } finally {
        // Removed the redundant catch (Exception) { printStackTrace; throw e; }.
        server.stop();
    }
}
// JSON (de)serialization of float[] payloads via the shaded Jackson ObjectMapper.
private static class FloatSerde implements JsonSerializer<float[]>, JsonDeserializer<float[]> {
    private final ObjectMapper om = new ObjectMapper();

    @Override
    public String serialize(float[] o) {
        try {
            return om.writeValueAsString(new FloatHolder(o));
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    @Override
    public float[] deserialize(String json) {
        try {
            return om.readValue(json, FloatHolder.class).getFloats();
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    // Wrapper type so Jackson serializes the array as an object field rather than "{}".
    @AllArgsConstructor
    @NoArgsConstructor
    @Data
    private static class FloatHolder {
        private float[] floats;
    }
}
// JSON (de)serialization of Integer results via the shaded Jackson ObjectMapper.
private static class IntSerde implements JsonSerializer<Integer>, JsonDeserializer<Integer> {
    private final ObjectMapper om = new ObjectMapper();

    @Override
    public String serialize(Integer o) {
        try {
            return om.writeValueAsString(o);
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    @Override
    public Integer deserialize(String json) {
        try {
            return om.readValue(json, Integer.class);
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }
}
}

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.val;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
/**
 * HTTP-level tests for the JsonModelServer servlet: content-type negotiation, routing,
 * and error codes. A stub SameDiff server is started before each test on a fixed port.
 *
 * Fixes over the original: every test leaked a CloseableHttpClient and its
 * CloseableHttpResponse (never closed); the base URL was repeated as a magic string.
 */
public class ServletTest {

    private static final int PORT = 8080;
    private static final String BASE_URL = "http://localhost:" + PORT;

    private JsonModelServer server;

    @Before
    public void setUp() throws Exception {
        // Stub adapter/serde: the servlet behavior under test is routing and content
        // types, not inference, so all callbacks return trivial values.
        val sd = SameDiff.create();
        server = new JsonModelServer.Builder<String, String>(sd)
                .port(PORT)
                .inferenceAdapter(new InferenceAdapter<String, String>() {
                    @Override
                    public MultiDataSet apply(String input) {
                        return null;
                    }

                    @Override
                    public String apply(INDArray... nnOutput) {
                        return null;
                    }
                })
                .outputSerializer(new JsonSerializer<String>() {
                    @Override
                    public String serialize(String o) {
                        return "";
                    }
                })
                .inputDeserializer(new JsonDeserializer<String>() {
                    @Override
                    public String deserialize(String json) {
                        return "";
                    }
                })
                .orderedInputNodes("input")
                .orderedOutputNodes("output")
                .build();

        server.start();
    }

    @After
    public void tearDown() throws Exception {
        server.stop();
    }

    /**
     * Executes the request with a one-shot client and returns the HTTP status code,
     * closing both the response and the client.
     */
    private static int statusOf(HttpUriRequest request) throws IOException {
        try (CloseableHttpClient client = HttpClientBuilder.create().build();
             CloseableHttpResponse response = client.execute(request)) {
            return response.getStatusLine().getStatusCode();
        }
    }

    /** Builds a GET against BASE_URL + path with the given Content-type header. */
    private static HttpGet get(String path, String contentType) {
        val request = new HttpGet(BASE_URL + path);
        request.setHeader("Content-type", contentType);
        return request;
    }

    /** Builds a POST against BASE_URL + path with the given Content-type header. */
    private static HttpPost post(String path, String contentType) {
        val request = new HttpPost(BASE_URL + path);
        request.setHeader("Content-type", contentType);
        return request;
    }

    @Test
    public void getEndpoints() throws IOException {
        assertEquals(200, statusOf(get("/v1", "application/json")));
    }

    @Test
    public void testContentTypeGet() throws IOException {
        // Wrong content type must be rejected with 415 Unsupported Media Type.
        assertEquals(415, statusOf(get("/v1", "text/plain")));
    }

    @Test
    public void testContentTypePost() throws Exception {
        assertEquals(415, statusOf(post("/v1/serving", "text/plain")));
    }

    @Test
    public void postForServing() throws Exception {
        // Valid route but empty body: the stub pipeline produces a server-side failure.
        assertEquals(500, statusOf(post("/v1/serving", "application/json")));
    }

    @Test
    public void testNotFoundPost() throws Exception {
        assertEquals(404, statusOf(post("/v1/serving/some", "application/json")));
    }

    @Test
    public void testNotFoundGet() throws Exception {
        assertEquals(404, statusOf(get("/v1/not_found", "application/json")));
    }
}

View File

@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import com.google.gson.Gson;
import lombok.*;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
/**
 * Simple request payload for the model-server tests. Lombok generates the builder,
 * accessors and constructors; the nested classes adapt it to the server's JSON serde SPI.
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class House {
    private int district;
    private int bedrooms;
    private int bathrooms;
    private int area;

    // Gson instances are thread-safe and stateless here; share one instead of
    // allocating a new parser on every (de)serialization call.
    private static final Gson GSON = new Gson();

    /** Serializes a House to its JSON representation. */
    public static class HouseSerializer implements JsonSerializer<House> {
        @Override
        public String serialize(@NonNull House o) {
            return GSON.toJson(o);
        }
    }

    /** Deserializes a House from its JSON representation. */
    public static class HouseDeserializer implements JsonDeserializer<House> {
        @Override
        public House deserialize(@NonNull String json) {
            return GSON.fromJson(json, House.class);
        }
    }
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Bridges the JSON-facing House/PredictedPrice types to the model's tensor interface.
 */
@Slf4j
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {

    @Override
    public MultiDataSet apply(@NonNull House input) {
        // Encode the house as a 1x4 float vector with every element set to the district
        // id — only the district influences the prediction in these tests.
        return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null);
    }

    @Override
    public PredictedPrice apply(INDArray... nnOutput) {
        // The first scalar of the first network output is the predicted price.
        float predicted = nnOutput[0].getFloat(0);
        return new PredictedPrice(predicted);
    }
}

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import com.google.gson.Gson;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
/**
 * Response payload for the model-server tests; the nested classes adapt it to the
 * server's JSON serde SPI.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PredictedPrice {
    private float price;

    // Gson instances are thread-safe and stateless here; share one instead of
    // allocating a new parser on every (de)serialization call.
    private static final Gson GSON = new Gson();

    /** Serializes a PredictedPrice to its JSON representation. */
    public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
        @Override
        public String serialize(@NonNull PredictedPrice o) {
            return GSON.toJson(o);
        }
    }

    /** Deserializes a PredictedPrice from its JSON representation. */
    public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
        @Override
        public PredictedPrice deserialize(@NonNull String json) {
            return GSON.fromJson(json, PredictedPrice.class);
        }
    }
}

View File

@ -0,0 +1,48 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.eclipse.jetty" level="WARN" />
<logger name="org.apache.catalina.core" level="WARN" />
<logger name="org.springframework" level="WARN" />
<logger name="org.nd4j" level="DEBUG" />
<logger name="org.deeplearning4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -0,0 +1,30 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<modules>
<module>deeplearning4j-json-server</module>
</modules>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
<name>deeplearning4j-remote</name>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project>

View File

@ -21,7 +21,6 @@ import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;

View File

@ -22,7 +22,6 @@ import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;

View File

@ -144,6 +144,7 @@
<module>dl4j-perf</module>
<module>dl4j-integration-tests</module>
<module>deeplearning4j-common</module>
<module>deeplearning4j-remote</module>
</modules>
<dependencyManagement>

View File

@ -163,9 +163,9 @@ if(CUDA_BLAS)
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
@ -173,24 +173,24 @@ if(CUDA_BLAS)
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0
if ("${COMPUTE}" STREQUAL "all")
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
else()
if ("${COMPUTE}" STREQUAL "all")
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
endif()
@ -205,34 +205,34 @@ if(CUDA_BLAS)
message("CUDA 10 Debug build")
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
elseif()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
elseif()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8
if ("${COMPUTE}" STREQUAL "all")
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
else()
if ("${COMPUTE}" STREQUAL "all")
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif()
endif()
endif()
@ -249,7 +249,7 @@ if(CUDA_BLAS)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)

View File

@ -344,7 +344,7 @@ bool NDArray::isS() const {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isR() const {
auto xType = ArrayOptions::dataType(this->_shapeInfo);
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8;
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16;
}
//////////////////////////////////////////////////////////////////////////

View File

@ -1769,6 +1769,17 @@ ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
typedef nd4j::LaunchContext OpaqueLaunchContext;
ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext();
ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
}
#endif //NATIVEOPERATIONS_NATIVEOPS_H

View File

@ -2985,6 +2985,38 @@ const char* runFullBenchmarkSuit(bool printOut) {
return chars;
}
// Exported accessor: hands out the process-wide default LaunchContext.
nd4j::LaunchContext* defaultLaunchContext() {
return LaunchContext::defaultContext();
}
// The lc* accessors below mirror the cuda backend's exported API
// (see the cuda NativeOps translation unit); on the cpu backend there are
// no streams, scratch buffers or blas/solver handles, so each returns null.
Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
return nullptr;
}
Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
return nullptr;
}
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);

View File

@ -356,7 +356,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@ -375,7 +375,7 @@ void NDArray::tile(NDArray& target) const {
auto stream = getContext()->getCudaStream();
prepareSpecialUse({&target}, {this});
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
registerSpecialUse({&target}, {this});
}
@ -434,7 +434,7 @@ void NDArray::repeat(int dimension, NDArray& target) const {
NDArray::prepareSpecialUse({&target}, {this});
auto stream = getContext()->getCudaStream();
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES);
NDArray::registerSpecialUse({&target}, {this});
}

View File

@ -23,6 +23,14 @@
#include <cuda.h>
#include <cuda_runtime.h>
// Device-side __noinline__ wrapper around shape::getIndexOffset.
// NOTE(review): presumably marked __noinline__ so the helper body is not
// duplicated into every instantiated lambda kernel below (compile time /
// binary size) - confirm against the shape:: helper definitions.
static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
return shape::getIndexOffset(index, shapeInfo, length);
}
// Same idea for shape::length: out-of-line length lookup from shapeInfo.
static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
@ -86,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = shape::length(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@ -95,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
z[e * zEws] = lambda(x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(x[xOffset]);
}
@ -115,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = shape::length(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@ -124,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
z[e * zEws] = lambda(e, x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(e, x[xOffset]);
}
@ -147,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = shape::length(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@ -156,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
}
@ -180,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = shape::length(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@ -189,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(x[xOffset], y[yOffset]);
}
@ -216,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = shape::length(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@ -225,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto wOffset = shape::getIndexOffset(e, wShapeInfo, zLength);
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
auto wOffset = __getIndexOffset(e, wShapeInfo, zLength);
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
}

View File

@ -28,6 +28,7 @@
#include <helpers/threshold.h>
#include <ops/specials_cuda.h>
#include <helpers/DebugHelper.h>
#include <AffinityManager.h>
#include <exceptions/datatype_exception.h>
#include <helpers/CudaLaunchHelper.h>
@ -1691,11 +1692,7 @@ void setOmpMinThreads(int threads) {
}
int getDevice() {
int curDevice = -1;
cudaGetDevice(&curDevice);
return curDevice;
return nd4j::AffinityManager::currentDeviceId();
}
void setElementThreshold(int num) {
@ -2391,8 +2388,8 @@ void sortByValue(Nd4jPointer *extraPointers,
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
@ -2406,7 +2403,7 @@ void sortByValue(Nd4jPointer *extraPointers,
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
@ -2430,7 +2427,7 @@ void sortByValue(Nd4jPointer *extraPointers,
int rev = 0;
do{
int half = n >> 1;
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
@ -3342,6 +3339,7 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
nd4j::graph::Context* createGraphContext(int nodeId) {
return new nd4j::graph::Context(nodeId);
}
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
return &ptr->randomGenerator();
}
@ -3460,3 +3458,35 @@ const char* runFullBenchmarkSuit(bool printOut) {
Nd4jLong getCachedMemory(int deviceId) {
return nd4j::ConstantHelper::getInstance()->getCachedAmount(deviceId);
}
// Exported accessor: hands out the process-wide default LaunchContext.
nd4j::LaunchContext* defaultLaunchContext() {
return LaunchContext::defaultContext();
}
// The lc* accessors below expose the cuda LaunchContext internals
// (streams, scratch buffers, library handles) as opaque Nd4jPointers.
Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
return lc->getScalarPointer();
}
Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
return lc->getReductionPointer();
}
Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
return lc->getAllocationPointer();
}
// execution stream (primary cudaStream_t of this context)
Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
return lc->getCudaStream();
}
// dedicated copy stream (special cudaStream_t of this context)
Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
return lc->getCudaSpecialStream();
}
Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
return lc->getCublasHandle();
}
Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
return lc->getCusolverHandle();
}

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_AFFINITYMANAGER_H
#define LIBND4J_AFFINITYMANAGER_H
#include <dll.h>
#include <pointercast.h>
#include <atomic>
#include <mutex>
namespace nd4j {
// Tracks which compute device the calling thread is bound to.
// The cpu backend collapses this to a single device (see the cpu
// AffinityManager.cpp, which returns constants and no-ops);
// NOTE(review): presumably the cuda backend maps these onto
// cudaGetDevice/cudaSetDevice - confirm against the cuda implementation.
class ND4J_EXPORT AffinityManager {
private:
// last device id observed/handed out; atomic since multiple threads may touch it
static std::atomic<int> _lastDevice;
// cached device count (presumably lazily initialized - confirm)
static int _numberOfDevices;
// presumably guards current-device bookkeeping - confirm in implementation
static std::mutex _currentMutex;
// presumably guards lazy initialization of _numberOfDevices - confirm
static std::mutex _numberMutex;
public:
// device id as the underlying runtime reports it
static int currentNativeDeviceId();
// logical device id used throughout the library
static int currentDeviceId();
// number of available devices
static int numberOfDevices();
// bind the calling thread to the given logical device
static void setCurrentDevice(int deviceId);
// bind the calling thread to the given native device
static void setCurrentNativeDevice(int deviceId);
};
}
#endif //LIBND4J_AFFINITYMANAGER_H

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_CONTEXTBUFFERS_H
#define LIBND4J_CONTEXTBUFFERS_H
#include <dll.h>
#include <pointercast.h>
namespace nd4j {
// Per-thread holder for scratch buffers used by a LaunchContext
// (reduction / scalar / allocation scratch areas).
// NOTE(review): the three pointer members have no in-class initializer,
// so the default constructor must set them explicitly - verify the
// backend implementations do so.
class ND4J_EXPORT ContextBuffers {
private:
void* _reductionPointer;
void* _scalarPointer;
void* _allocationPointer;
// whether this instance owns (and should release) the buffers
bool _allocated = true;
// device these buffers belong to; -1 until assigned
int _deviceId = -1;
// backend-specific buffer setup
void initialize();
public:
ContextBuffers();
// adopt externally supplied buffers; isOwner marks responsibility for freeing them
ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
~ContextBuffers();
void* reductionBuffer();
void* scalarBuffer();
void* allocationBuffer();
void setReductionBuffer(void* pointer);
void setScalarBuffer(void* pointer);
void setAllocationBuffer(void* pointer);
// change ownership flag after construction
void triggerOwnership(bool isOwner);
int deviceId();
};
}
#endif //LIBND4J_CONTEXTBUFFERS_H

View File

@ -35,6 +35,8 @@
#include <op_boilerplate.h>
#include <memory/Workspace.h>
#include <vector>
#include <mutex>
#include <execution/ContextBuffers.h>
@ -44,49 +46,44 @@ class ND4J_EXPORT LaunchContext {
private:
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
static std::mutex _mutex;
#ifdef __CUDABLAS__
#ifndef __JAVACPP_HACK__
void* _reductionPointer;
void* _scalarPointer;
int* _allocationPointer;
cudaStream_t *_cudaStream = nullptr;
cudaStream_t *_cudaSpecialStream = nullptr;
void *_cublasHandle = nullptr;
cudaStream_t* _cudaStream = nullptr;
cudaStream_t* _cudaSpecialStream = nullptr;
void* _cublasHandle = nullptr;
void* _cusolverHandle = nullptr;
#endif // JCPP
bool _isAllocated = false;
#endif // CUDA
nd4j::memory::Workspace* _workspace = nullptr;
int _deviceID = 0;
nd4j::memory::Workspace* _workspace = nullptr;
int _deviceID = 0;
public:
#ifdef __CUDABLAS__
#ifndef __JAVACPP_HACK__
LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr);
FORCEINLINE void* getReductionPointer () const {return _reductionPointer;};
void* getReductionPointer () const;
void* getScalarPointer() const;
int* getAllocationPointer() const;
void* getCublasHandle() const;
void* getCusolverHandle() const;
cudaStream_t* getCudaStream() const;
cudaStream_t* getCudaSpecialStream() const;
FORCEINLINE void* getScalarPointer() const {return _scalarPointer;};
FORCEINLINE int* getAllocationPointer() const {return _allocationPointer;};
FORCEINLINE void* getCublasHandle() const {return _cublasHandle;};
FORCEINLINE cudaStream_t* getCudaStream() const {return _cudaStream;};
FORCEINLINE cudaStream_t* getCudaSpecialStream() const {return _cudaSpecialStream;};
FORCEINLINE void setReductionPointer (void* reductionPointer) {_reductionPointer = reductionPointer;};
FORCEINLINE void setScalarPointer(void* scalarPointer) {_scalarPointer = scalarPointer;};
FORCEINLINE void setAllocationPointer(int* allocationPointer) {_allocationPointer = allocationPointer;};
FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; };
void setReductionPointer (void* reductionPointer);
void setScalarPointer(void* scalarPointer);
void setAllocationPointer(int* allocationPointer);
void setCudaStream(cudaStream_t* cudaStream);
void setCudaSpecialStream(cudaStream_t* cudaStream);
void setCublasHandle(void *handle);
#endif // JCPP

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/AffinityManager.h>
namespace nd4j {
// cpu backend: there is only ever one logical "device", so affinity
// tracking collapses to constants and no-ops.
int AffinityManager::currentDeviceId() {
return 0;
}
int AffinityManager::currentNativeDeviceId() {
return 0;
}
int AffinityManager::numberOfDevices() {
return 1;
}
void AffinityManager::setCurrentDevice(int deviceId) {
// no-op
}
void AffinityManager::setCurrentNativeDevice(int deviceId) {
// no-op
}
}

View File

@ -0,0 +1,74 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/ContextBuffers.h>
#include <execution/AffinityManager.h>
namespace nd4j {
    // Per-thread scratch-buffer holder. On the cpu backend no device-side
    // buffers exist: the pointers stay null unless callers install their own.

    ContextBuffers::ContextBuffers() {
        // The class declaration gives the pointer members no default value;
        // zero them explicitly so the accessors never return indeterminate
        // pointers when nothing was installed.
        _reductionPointer = nullptr;
        _scalarPointer = nullptr;
        _allocationPointer = nullptr;
        _deviceId = AffinityManager::currentDeviceId();
    }

    ContextBuffers::~ContextBuffers() {
        // nothing is ever allocated by this class on the cpu backend
    }

    // Adopt externally supplied buffers; isOwner records whether this instance
    // would be responsible for releasing them (unused on the cpu backend).
    // NOTE(review): _deviceId intentionally stays at its declared default (-1)
    // here, matching the original behavior - confirm callers expect that.
    ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
        _reductionPointer = rPointer;
        _scalarPointer = sPointer;
        _allocationPointer = aPointer;
        _allocated = isOwner;
    }

    void ContextBuffers::initialize() {
        // no-op: the cpu backend has no device-side buffers to create
    }

    void* ContextBuffers::reductionBuffer() {
        return _reductionPointer;
    }

    void* ContextBuffers::scalarBuffer() {
        return _scalarPointer;
    }

    void* ContextBuffers::allocationBuffer() {
        return _allocationPointer;
    }

    void ContextBuffers::setReductionBuffer(void* pointer) {
        _reductionPointer = pointer;
    }

    void ContextBuffers::setScalarBuffer(void* pointer) {
        _scalarPointer = pointer;
    }

    void ContextBuffers::setAllocationBuffer(void* pointer) {
        _allocationPointer = pointer;
    }

    void ContextBuffers::triggerOwnership(bool isOwner) {
        _allocated = isOwner;
    }

    int ContextBuffers::deviceId() {
        return _deviceId;
    }
}

View File

@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 30.11.17.
//
#include <execution/LaunchContext.h>
#include <logger.h>
#include <exceptions/cuda_exception.h>
#include <mutex>
#include <thread>
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
namespace nd4j {

    LaunchContext::~LaunchContext() {
        // no resources owned on the CPU backend
    }

    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();

    ////////////////////////////////////////////////////////////////////////
    LaunchContext::LaunchContext() {
        // default constructor, just to make clang/ranlib happy
        _workspace = nullptr;
        _deviceID = 0;
    }

    // CUDA-style constructor kept for interface parity; the pointers are
    // meaningless on the CPU backend and are deliberately ignored
    LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {

    }

    LaunchContext* LaunchContext::defaultContext() {
        // TODO: we need it to be device-aware, but only once we add NUMA support for cpu

        // guard the lazy initialization against concurrent first calls,
        // mirroring the locking done by the CUDA backend's defaultContext()
        static std::mutex initMutex;
        std::lock_guard<std::mutex> lock(initMutex);

        if (LaunchContext::_contexts.empty()) {
            LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
        }

        // return context for current device
        return LaunchContext::_contexts[0].get();
    }
}

View File

@ -0,0 +1,108 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <logger.h>
#include <execution/AffinityManager.h>
#include <exceptions/cuda_exception.h>
thread_local int globalThreadToDevice = -1;
namespace nd4j {
std::mutex AffinityManager::_currentMutex;
std::mutex AffinityManager::_numberMutex;
int AffinityManager::_numberOfDevices = -1;
int AffinityManager::currentDeviceId() {
// if there's no affinity set - set it now
if (globalThreadToDevice < 0) {
// this block must be thread-local
_currentMutex.lock();
globalThreadToDevice = _lastDevice++;
// we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise
if (globalThreadToDevice >= numberOfDevices()) {
globalThreadToDevice = 0;
_lastDevice = numberOfDevices() > 1 ? 1 : 0;
}
_currentMutex.unlock();
setCurrentDevice(globalThreadToDevice);
}
// if we already know affinity - just return it
if (globalThreadToDevice >= 0)
return globalThreadToDevice;
int dev = 0;
auto res = cudaGetDevice(&dev);
if (res != 0)
throw cuda_exception::build("cudaGetDevice failed", res);
return dev;
}
int AffinityManager::currentNativeDeviceId() {
int dev = 0;
auto res = cudaGetDevice(&dev);
if (res != 0)
throw cuda_exception::build("cudaGetDevice failed", res);
return dev;
}
int AffinityManager::numberOfDevices() {
_numberMutex.lock();
// we want to cache number of devices
if (_numberOfDevices <= 0) {
int dev = 0;
auto res = cudaGetDeviceCount(&dev);
if (res != 0)
throw cuda_exception::build("cudaGetDeviceCount failed", res);
_numberOfDevices = dev;
}
_numberMutex.unlock();
return _numberOfDevices;
}
void AffinityManager::setCurrentNativeDevice(int deviceId) {
auto res = cudaSetDevice(deviceId);
}
void AffinityManager::setCurrentDevice(int deviceId) {
auto res = cudaSetDevice(deviceId);
if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
// update thread-device affinity
globalThreadToDevice = deviceId;
// TODO: update context buffers?
}
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
}

View File

@ -0,0 +1,116 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/ContextBuffers.h>
#include <logger.h>
#include <AffinityManager.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_device_runtime_api.h>
namespace nd4j {
    ContextBuffers::ContextBuffers() {
        nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId());
        _deviceId = AffinityManager::currentDeviceId();
    }

    ContextBuffers::~ContextBuffers() {
        // free device buffers only when this instance owns them
        if (_allocated) {
            nd4j_printf("Releasing ContextBuffers\n","");

            if (_allocationPointer != nullptr)
                cudaFree(_allocationPointer);

            if (_scalarPointer != nullptr)
                cudaFree(_scalarPointer);

            // bugfix: original code tested _allocationPointer here while
            // freeing _reductionPointer (copy-paste error)
            if (_reductionPointer != nullptr)
                cudaFree(_reductionPointer);
        }
    }

    // wraps externally-provided device pointers; ownership flag taken as-is
    ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
        _reductionPointer = rPointer;
        _scalarPointer = sPointer;
        _allocationPointer = aPointer;
        _allocated = isOwner;
    }

    // Lazily allocates the per-thread device scratch buffers (called from the
    // getters below on first use).
    void ContextBuffers::initialize() {
        nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId());

        auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_reductionPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
        if (res != 0)
            throw std::runtime_error("_scalarPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_allocationPointer allocation failed");

        _allocated = true;
    }

    void* ContextBuffers::reductionBuffer() {
        if (_reductionPointer == nullptr)
            initialize();

        return _reductionPointer;
    }

    void* ContextBuffers::scalarBuffer() {
        if (_scalarPointer == nullptr)
            initialize();

        return _scalarPointer;
    }

    void* ContextBuffers::allocationBuffer() {
        if (_allocationPointer == nullptr)
            initialize();

        return _allocationPointer;
    }

    void ContextBuffers::setReductionBuffer(void* pointer) {
        _reductionPointer = pointer;
    }

    void ContextBuffers::setScalarBuffer(void* pointer) {
        _scalarPointer = pointer;
    }

    void ContextBuffers::setAllocationBuffer(void* pointer) {
        _allocationPointer = pointer;
    }

    void ContextBuffers::triggerOwnership(bool isOwner) {
        _allocated = isOwner;
    }

    int ContextBuffers::deviceId() {
        return _deviceId;
    }
}

View File

@ -0,0 +1,182 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 30.11.17.
//
#include <execution/LaunchContext.h>
#include <logger.h>
#include <exceptions/cuda_exception.h>
#include <helpers/cublasHelper.h>
#include <thread>
#include <execution/AffinityManager.h>
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
namespace nd4j {

    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
    std::mutex LaunchContext::_mutex;

    ////////////////////////////////////////////////////////////////////////
    // wraps externally-owned streams; this instance does NOT take ownership
    LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
        _cudaStream = cudaStream;
        _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
        //_reductionPointer = reductionPointer;
        //_scalarPointer = scalarPointer;
        //_allocationPointer = allocationPointer;
        _workspace = nullptr;
        _isAllocated = false;
    }

    LaunchContext::~LaunchContext() {
        // destroy streams only when this context created them itself
        if (_isAllocated) {
            cudaStreamSynchronize(*_cudaStream);
            cudaStreamSynchronize(*_cudaSpecialStream);

            cudaStreamDestroy(*_cudaStream);
            cudaStreamDestroy(*_cudaSpecialStream);

            delete _cudaStream;
            delete _cudaSpecialStream;
        }
    }

    ////////////////////////////////////////////////////////////////////////
    LaunchContext::LaunchContext() {
        // default constructor, just to make clang/ranlib happy
        _workspace = nullptr;
        _deviceID = 0;

        _isAllocated = true;
        _cudaStream = new cudaStream_t();
        _cudaSpecialStream = new cudaStream_t();
        if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
            throw std::runtime_error("Failed to allocate memory for new CUDA stream");

        cudaError_t err = cudaStreamCreate(_cudaStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);

        err = cudaStreamCreate(_cudaSpecialStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);

        // cuBLAS/cuSolver handles are per-device and shared via CublasHelper
        _cublasHandle = CublasHelper::getInstance()->handle();

        _cusolverHandle = CublasHelper::getInstance()->solver();

        auto res = cudaStreamSynchronize(*_cudaStream);
        if (res != 0)
            throw cuda_exception::build("Initial sync failed", res);
    }

    LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
        _isAllocated = false;
        _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        // NOTE(review): special stream aliases the main stream here - confirm intended
        _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        //_reductionPointer = reductionPointer;
        //_scalarPointer = scalarPointer;
        //_allocationPointer = reinterpret_cast<int *>(allocationPointer);
    }

    LaunchContext* LaunchContext::defaultContext() {
        /**
         * This method returns LaunchContext, that has multiple entities within:
         * 1) temporary buffers. they must be per-thread
         * 2) CUDA stream. it must be either per-thread or per-device
         * 3) cuBLAS handle. it must be per-device
         */
        auto deviceId = AffinityManager::currentDeviceId();

        {
            // we need this block synchronous, to avoid double initialization etc;
            // lock_guard keeps the mutex balanced even if a LaunchContext ctor throws
            std::lock_guard<std::mutex> lock(_mutex);
            if (LaunchContext::_contexts.empty()) {
                // create one context per device
                auto numDevices = AffinityManager::numberOfDevices();

                _contexts.resize(numDevices);
                for (int e = 0; e < numDevices; e++) {
                    AffinityManager::setCurrentDevice(e);

                    LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
                }

                // don't forget to restore device back again
                AffinityManager::setCurrentDevice(deviceId);
            }
        }

        // return context for current device
        return LaunchContext::_contexts[deviceId].get();
    }

    // temporary buffers live in the thread_local contextBuffers instance
    void* LaunchContext::getReductionPointer () const {
        return contextBuffers.reductionBuffer();
    };

    void* LaunchContext::getScalarPointer() const {
        return contextBuffers.scalarBuffer();
    };

    int* LaunchContext::getAllocationPointer() const {
        return reinterpret_cast<int*>(contextBuffers.allocationBuffer());
    };

    void* LaunchContext::getCublasHandle() const {
        return _cublasHandle;
    };

    void* LaunchContext::getCusolverHandle() const {
        return _cusolverHandle;
    };

    cudaStream_t* LaunchContext::getCudaStream() const {
        return _cudaStream;
    };

    cudaStream_t* LaunchContext::getCudaSpecialStream() const {
        return _cudaSpecialStream;
    };

    void LaunchContext::setReductionPointer (void* reductionPointer) {
        contextBuffers.setReductionBuffer(reductionPointer);
    };

    void LaunchContext::setScalarPointer(void* scalarPointer) {
        contextBuffers.setScalarBuffer(scalarPointer);
    };

    void LaunchContext::setAllocationPointer(int* allocationPointer) {
        contextBuffers.setAllocationBuffer(allocationPointer);
    };

    void LaunchContext::setCudaStream(cudaStream_t* cudaStream) {
        _cudaStream = cudaStream;
    };

    void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) {
        _cudaSpecialStream = cudaStream;
    };

    void LaunchContext::setCublasHandle(void *handle) {
        _cublasHandle = handle;
    };
}

View File

@ -1,130 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 30.11.17.
//
#include <execution/LaunchContext.h>
#include <logger.h>
#include <exceptions/cuda_exception.h>
#include <helpers/cublasHelper.h>
namespace nd4j {

// Legacy unified implementation: CUDA-specific members are compiled in only
// when __CUDABLAS__ is defined; the CPU build reduces to near no-ops.
#ifdef __CUDABLAS__
    ////////////////////////////////////////////////////////////////////////
    // wraps externally-owned streams/buffers; does not take ownership
    LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
        _cudaStream = cudaStream;
        _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
        _reductionPointer = reductionPointer;
        _scalarPointer = scalarPointer;
        _allocationPointer = allocationPointer;
        _workspace = nullptr;
        _isAllocated = false;
    }
#endif

    // releases streams, device buffers and the cuBLAS handle, but only when
    // this context allocated them itself (_isAllocated)
    LaunchContext::~LaunchContext() {
#ifdef __CUDABLAS__
        if (_isAllocated) {
            cudaStreamSynchronize(*_cudaStream);
            cudaStreamSynchronize(*_cudaSpecialStream);

            cudaStreamDestroy(*_cudaStream);
            cudaStreamDestroy(*_cudaSpecialStream);

            delete _cudaStream;
            delete _cudaSpecialStream;

            cudaFree(_reductionPointer);
            cudaFree(_allocationPointer);
            cudaFree(_scalarPointer);

            cublas::destroyHandle(_cublasHandle);
        }
#endif
    }

    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();

////////////////////////////////////////////////////////////////////////
    LaunchContext::LaunchContext() {
        // default constructor, just to make clang/ranlib happy
        _workspace = nullptr;
        _deviceID = 0;

#ifdef __CUDABLAS__
        // CUDA build: create our own streams, handle and scratch buffers
        _isAllocated = true;
        _cudaStream = new cudaStream_t();
        _cudaSpecialStream = new cudaStream_t();
        if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
            throw std::runtime_error("Failed to allocate memory for new CUDA stream");

        cudaError_t err = cudaStreamCreate(_cudaStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);

        err = cudaStreamCreate(_cudaSpecialStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);

        _cublasHandle = cublas::handle();

        auto res = cudaStreamSynchronize(*_cudaStream);
        if (res != 0)
            throw cuda_exception::build("Initial sync failed", res);

        // NOTE(review): if a later allocation below throws, the earlier
        // allocations are not released - confirm acceptable for this legacy path
        res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_reductionPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 8);
        if (res != 0)
            throw std::runtime_error("_scalarPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_allocationPointer allocation failed");
#else
        //
#endif
    }

    // wraps an externally-provided stream and buffers; does not take ownership
    LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
#ifdef __CUDABLAS__
        _isAllocated = false;
        _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        // NOTE(review): special stream aliases the main stream - confirm intended
        _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        _reductionPointer = reductionPointer;
        _scalarPointer = scalarPointer;
        _allocationPointer = reinterpret_cast<int *>(allocationPointer);
#else
        // no-op
#endif
    }

    LaunchContext* LaunchContext::defaultContext() {
        // TODO: we need it to be device-aware
        // NOTE(review): lazy init is unsynchronized here - confirm first use is single-threaded
        if (LaunchContext::_contexts.empty()) {
            LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
        }

        return LaunchContext::_contexts[0].get();
    }
}

View File

@ -21,6 +21,7 @@
#ifndef __CUDABLAS__
#include <ConstantHelper.h>
#include <execution/AffinityManager.h>
#include <types/types.h>
#include <loops/type_conversions.h>
#include <type_boilerplate.h>
@ -59,11 +60,11 @@ namespace nd4j {
}
int ConstantHelper::getCurrentDevice() {
return 0L;
return AffinityManager::currentDeviceId();
}
int ConstantHelper::getNumberOfDevices() {
return 1;
return AffinityManager::numberOfDevices();
}
ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {

View File

@ -21,6 +21,7 @@
#include "../MmulHelper.h"
#include <NDArrayFactory.h>
#include <helpers/BlasHelper.h>
#include <exceptions/datatype_exception.h>
namespace nd4j {
@ -147,7 +148,12 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
//////////////////////////////////////////////////////////////////////////////
// MXK x KxN = MxN
NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
if (A->dataType() != B->dataType())
throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType());
if (C != nullptr && A->dataType() != C->dataType())
throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType());
if(A->rankOf() != 2)
throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !");
@ -212,7 +218,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
}
else {
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}
if(pC != C) {
@ -230,6 +237,11 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
////////////////////////////////////////////////////////////////////////////
// MXN x N = M
NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) {
if (X->dataType() != A->dataType())
throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType());
if (Y != nullptr && X->dataType() != Y->dataType())
throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType());
int xLenDim, yLenDim(0);
@ -279,7 +291,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy);
}
else {
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}
if(pA != A)
@ -291,6 +304,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
////////////////////////////////////////////////////////////////////////////
// (X * Y) = Z[0]
NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) {
if (X->dataType() != Y->dataType())
throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType());
if (Z != nullptr && X->dataType() != Z->dataType())
throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType());
int xLenDim(0), yLenDim(0);
@ -316,13 +334,14 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
const auto yType = Y->dataType();
const auto zType = Z->dataType();
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES);
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
return Z;
}
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}

View File

@ -21,13 +21,41 @@
#include "../cublasHelper.h"
namespace nd4j {
namespace cublas {
void* handle() {
return nullptr;
}
void destroyHandle(void* handle) {
//
}
static void* handle_() {
return nullptr;
}
static void destroyHandle_(void* handle) {
}
CublasHelper::CublasHelper() {
}
CublasHelper::~CublasHelper() {
}
CublasHelper* CublasHelper::getInstance() {
if (!_INSTANCE)
_INSTANCE = new nd4j::CublasHelper();
return _INSTANCE;
}
void* CublasHelper::handle() {
return nullptr;
}
void* CublasHelper::solver() {
return nullptr;
}
void* CublasHelper::handle(int deviceId) {
return nullptr;
}
nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0;
}

View File

@ -21,12 +21,28 @@
#ifndef DEV_TESTS_CUBLASHELPER_H
#define DEV_TESTS_CUBLASHELPER_H
namespace nd4j {
namespace cublas {
void* handle();
#include <dll.h>
#include <pointercast.h>
#include <vector>
void destroyHandle(void* handle);
}
namespace nd4j {
class CublasHelper {
private:
static CublasHelper *_INSTANCE;
std::vector<void*> _cache;
std::vector<void*> _solvers;
CublasHelper();
~CublasHelper();
public:
static CublasHelper* getInstance();
void* solver();
void* handle();
void* handle(int deviceId);
};
}
#endif //DEV_TESTS_CUBLASHELPER_H

View File

@ -26,6 +26,7 @@
#include <logger.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <execution/AffinityManager.h>
#define CONSTANT_LIMIT 49152
@ -43,23 +44,11 @@ namespace nd4j {
}
int ConstantHelper::getCurrentDevice() {
int dev = 0;
auto res = cudaGetDevice(&dev);
if (res != 0)
throw cuda_exception::build("cudaGetDevice failed", res);
return dev;
return AffinityManager::currentDeviceId();
}
int ConstantHelper::getNumberOfDevices() {
int dev = 0;
auto res = cudaGetDeviceCount(&dev);
if (res != 0)
throw cuda_exception::build("cudaGetDeviceCount failed", res);
return dev;
return AffinityManager::numberOfDevices();
}

View File

@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
}
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
// BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
threadsPerBlock.x = 512;
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
}
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
// BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
NDArray::prepareSpecialUse({Z}, {X, Y});
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
// BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
return Z;
}
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
}

View File

@ -20,12 +20,15 @@
#include <cublas_v2.h>
#include <cusolverDn.h>
#include "../cublasHelper.h"
#include <exceptions/cuda_exception.h>
#include <helpers/logger.h>
#include <execution/AffinityManager.h>
namespace nd4j {
void* cublas::handle() {
static void* handle_() {
auto _handle = new cublasHandle_t();
auto status = cublasCreate_v2(_handle); // initialize CUBLAS context
if (status != CUBLAS_STATUS_SUCCESS)
@ -34,7 +37,16 @@ namespace nd4j {
return reinterpret_cast<void *>(_handle);
}
void cublas::destroyHandle(void* handle) {
static void* solver_() {
auto cusolverH = new cusolverDnHandle_t();
auto status = cusolverDnCreate(cusolverH);
if (status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("cuSolver handle creation failed !", status);
return cusolverH;
}
static void destroyHandle_(void* handle) {
auto ch = reinterpret_cast<cublasHandle_t *>(handle);
auto status = cublasDestroy_v2(*ch);
if (status != CUBLAS_STATUS_SUCCESS)
@ -42,4 +54,57 @@ namespace nd4j {
delete ch;
}
CublasHelper::CublasHelper() {
auto numDevices = AffinityManager::numberOfDevices();
auto currentDevice = AffinityManager::currentDeviceId();
_cache.resize(numDevices);
_solvers.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
AffinityManager::setCurrentDevice(e);
_cache[e] = handle_();
_solvers[e] = solver_();
}
// don't forget to restore back original device
AffinityManager::setCurrentDevice(currentDevice);
}
CublasHelper::~CublasHelper() {
auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++)
destroyHandle_(_cache[e]);
}
CublasHelper* CublasHelper::getInstance() {
if (!_INSTANCE)
_INSTANCE = new nd4j::CublasHelper();
return _INSTANCE;
}
void* CublasHelper::handle() {
auto deviceId = AffinityManager::currentDeviceId();
return handle(deviceId);
}
void* CublasHelper::solver() {
auto deviceId = AffinityManager::currentDeviceId();
if (deviceId < 0 || deviceId > _solvers.size())
throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);
return _solvers[deviceId];
}
void* CublasHelper::handle(int deviceId) {
if (deviceId < 0 || deviceId > _cache.size())
throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);
return _cache[deviceId];
}
nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0;
}

View File

@ -276,23 +276,6 @@ namespace functions {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
}
// FIXME: eventually we might want to get rid of that
#ifndef __CLION_IDE__
/*
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
*/
#endif
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}

View File

@ -60,9 +60,18 @@ static __global__ void broadcastInverseSimple(
functions::broadcast::Broadcast<X,Y,Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
namespace functions {
namespace broadcast {
static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
return shape::getIndexOffset(index, shapeInfo, length);
}
static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
template<typename X, typename Y, typename Z>
template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@ -120,9 +129,9 @@ namespace functions {
if (threadIdx.x == 0) {
tadLength = shape::length(tadOnlyShapeInfo);
tadLength = _length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = shape::length(yShapeInfo) / tadLength;
numTads = _length(yShapeInfo) / tadLength;
xEWS = shape::elementWiseStride(xShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}
@ -146,9 +155,9 @@ namespace functions {
else {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
auto xOffset = _getIndexOffset(i, xShapeInfo, tadLength);
auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
}
}
@ -186,9 +195,9 @@ namespace functions {
if (threadIdx.x == 0) {
tadLength = shape::length(tadOnlyShapeInfo);
tadLength = _length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = shape::length(xShapeInfo) / tadLength;
numTads = _length(xShapeInfo) / tadLength;
yEWS = shape::elementWiseStride(yShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}
@ -212,14 +221,15 @@ namespace functions {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
auto yOffset = _getIndexOffset(i, yShapeInfo, tadLength);
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
}
}
}
}
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);

View File

@ -0,0 +1,115 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#include <loops/broadcasting.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <Environment.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <stdexcept>
#include <StringUtils.h>
#include <specials_cuda.h>
namespace functions {
namespace broadcast {
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
//
}
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
/**
* CPU execution
* @param x the input
* @param xShapeInfo the x shape information
* @param y the y data
* @param yShapeInfo the y shape information
* @param result the result
* @param resultShapeInfo the result shape information
* @param dimension the dimension to broadcast along long
* @param dimensionLength the length of the dimension buffer
*/
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
//
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
}
}

View File

@ -224,6 +224,77 @@ namespace functions {
}
}
template<typename X, typename Y>
void BroadcastBool<X,Y>::exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
void BroadcastBool<X,Y>::execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
template<typename OpType>
void BroadcastBool<X,Y>::exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
template<typename OpType>
void BroadcastBool<X,Y>::execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
}
}

View File

@ -361,6 +361,32 @@ namespace functions {
}
}
template <typename T>
Nd4jLong IndexReduce<T>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename T>
void IndexReduce<T>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
template <typename T>
template<typename OpType>
Nd4jLong IndexReduce<T>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename T>
template<typename OpType>
_CUDA_H void IndexReduce<T>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES);
}
}

View File

@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../pairwise_transform.h"
namespace functions {
namespace pairwise_transforms {
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams) {
}
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong xStride,
void *y,
Nd4jLong yStride,
void *z,
Nd4jLong resultStride,
void *extraParams,
Nd4jLong len) {
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void PairWiseTransform<X, Y, Z>:: exec(
void *vx,
Nd4jLong* xShapeInfo,
void *vy,
Nd4jLong* yShapeInfo,
void *vresult,
Nd4jLong* zShapeInfo,
void *vextraParams) {
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void PairWiseTransform<X, Y, Z>::exec(void *vx,
Nd4jLong xStride,
void *vy,
Nd4jLong yStride,
void *vresult,
Nd4jLong resultStride,
void *vextraParams,
const Nd4jLong len) {
}
}
}

View File

@ -110,6 +110,63 @@ void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
}
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::exec(
const int opNum,
void *dx,
Nd4jLong *xShapeBuffer,
void *y,
Nd4jLong *yShapeBuffer,
void *result,
Nd4jLong *resultShapeBuffer,
void *extraParams) {
}
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::exec(
const int opNum,
void *dx,
Nd4jLong xStride,
void *y,
Nd4jLong yStride,
void *result,
Nd4jLong resultStride,
void *extraParams,
Nd4jLong n) {
}
template<typename X, typename Y>
template<typename OpType>
void PairWiseBoolTransform<X,Y>::exec(
void *vx,
Nd4jLong* xShapeBuffer,
void *vy,
Nd4jLong* yShapeBuffer,
void *vresult,
Nd4jLong* resultShapeBuffer,
void *vextraParams) {
}
template<typename X, typename Y>
template<typename OpType>
void PairWiseBoolTransform<X,Y>::exec(void *vx,
Nd4jLong xStride,
void *vy,
Nd4jLong yStride,
void *vresult,
Nd4jLong resultStride,
void *vextraParams,
const Nd4jLong n) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
}

View File

@ -442,6 +442,39 @@ namespace functions {
DEBUG_KERNEL(stream, opNum);
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}
}

View File

@ -0,0 +1,82 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#include <loops/reduce3.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <specials_cuda.h>
namespace functions {
namespace reduce3 {
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
}
}
}

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "loops/scalar.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <op_boilerplate.h>
#include <helpers/TAD.h>
#include <types/types.h>
namespace functions {
namespace scalar {
}
}

View File

@ -231,6 +231,41 @@ void ScalarBoolTransform<X,Y>::executeCudaAlongDimension(dim3& launchDims, cudaS
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
template<typename X, typename Y>
template <typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
template<typename X, typename Y>
template<typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X, typename Y>
template<typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
}
}

View File

@ -21,84 +21,6 @@
#include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int half = window>>1;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
}
__syncthreads();
//for (int i = 0; i < length; i+= window)
/*
if window == 4;
iterations will be: 0; 4; 8; 12; 16; 20
if gridDim = 3;
on first iteration we'll have: 0; 4; 8;
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
*/
int firstPosition;
int firstStep;
int secondPosition;
int secondStep;
int WARP_SIZE = 32;
int numWarps = (gridDim.x * blockDim.x) / 32;
int warpId = tid / WARP_SIZE;
int warpIdx = tid % WARP_SIZE;
if (half >= 128) {
firstPosition = blockIdx.x * window;
firstStep = gridDim.x * window;
secondPosition = threadIdx.x;
secondStep = blockDim.x;
} else if (half >= 32) {
firstPosition = warpId * window;
firstStep = numWarps * window;
secondPosition = warpIdx;
secondStep = WARP_SIZE;
} else {
firstPosition = tid * window;
firstStep = blockDim.x * gridDim.x * window;
secondPosition = 0;
secondStep = 1;
}
for (int i = firstPosition; i < length; i += firstStep) {
for (int j = secondPosition; j < half; j += secondStep) {
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);
Y v0 = y[posIJ];
Y v1 = y[posIT];
if(!descending == (v0 > v1)) {
y[posIJ] = v1;
y[posIT] = v0;
X xtemp = x[posIJ];
x[posIJ] = x[posIT];
x[posIT] = xtemp;
}
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
@ -264,11 +186,5 @@ __host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *str
bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -21,60 +21,6 @@
#include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
unsigned int i, ixj; /* Sorting partners: i and ixj */
i = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0)
xLength = shape::length(xShapeInfo);
__syncthreads();
if (i >= length)
return;
ixj = i^j;
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
if (!descending == (y[posI]>y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
} else if ((i&k)!=0) {
/* Sort descending */
if (!descending == (y[posI]<y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
@ -189,13 +135,6 @@ __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream,
bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -62,9 +62,9 @@ namespace nd4j {
}
}
}
BUILD_DOUBLE_TEMPLATE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer,
BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer,
Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES, LIBND4J_TYPES);
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES);
template <typename T>
void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength,
@ -88,10 +88,10 @@ namespace nd4j {
dim3 launchDims(256, 512, 8192);
repeatKernelDouble<X,Y><<<launchDims.x, launchDims.y, launchDims.z, stream>>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets);
}
BUILD_DOUBLE_TEMPLATE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
BUILD_SINGLE_TEMPLATE_TWICE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets,
cudaStream_t stream), LIBND4J_TYPES, LIBND4J_TYPES);
cudaStream_t stream), LIBND4J_TYPES);
}

View File

@ -21,6 +21,17 @@
#include <loops/special_kernels.h>
namespace nd4j {
static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
return shape::getIndexOffset(index, shapeInfo, length);
}
static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
return shape::subArrayOffset(index, shapeInfoA, shapeInfoB);
}
static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
////////////////////////////////////////////////////////////////////////
template<typename T>
@ -34,31 +45,20 @@ namespace nd4j {
//const auto resultLength = shape::length(outputShape);
if (shape::order(outputShape) == 'c') { // ews == 1 always here
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<T *>(outputBuffer) + i) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
}
// for(Nd4jLong i=0; i<resultLen; ++i) {
// auto yOffset = shape::subArrayOffset(newShapeInfo, _shapeInfo, i);
// BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, i, this->_buffer, yOffset), LIBND4J_TYPES);
//
// }
} else {
//
//auto inputLength = shape::lenght(inputShape);
for (int i = tid; i < resultLength; i += totalThreads) {
auto xOffset = shape::getIndexOffset(i, outputShape, resultLength);
auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) +
yOffset);
// BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, xOffset, this->_buffer, yOffset), LIBND4J_TYPES);
auto xOffset = _getIndexOffset(i, outputShape, resultLength);
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
}
}
}
BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,
(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength),
LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), LIBND4J_TYPES);
template<typename T>
void tileKernelH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, cudaStream_t *stream) {
@ -77,29 +77,26 @@ namespace nd4j {
if (ordering == 'c' && ews == 1) { // ews == 1 always here
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) +
yOffset));
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
} else if (ordering == 'c' && ews > 1) {
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i * ews) = static_cast<X>(*(
reinterpret_cast<Y const *>(inputBuffer) + yOffset));
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i * ews) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
} else {
for (int i = tid; i < resultLength; i += totalThreads) {
auto xOffset = shape::getIndexOffset(i, outputShape, resultLength);
auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(
reinterpret_cast<Y const *>(inputBuffer) + yOffset));
auto xOffset = _getIndexOffset(i, outputShape, resultLength);
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
}
}
BUILD_DOUBLE_TEMPLATE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES);
template<typename X, typename Y>
void tileKernelHH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) {
@ -107,5 +104,5 @@ namespace nd4j {
tileKernelDouble<X, Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews);
}
BUILD_DOUBLE_TEMPLATE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES);
}

View File

@ -413,6 +413,74 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
DEBUG_KERNEL(stream, opNum);
}
template <typename X, typename Y>
Y SummaryStatsReduce<X,Y>::execScalar(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
return 0;
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::execScalar(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer) {
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::exec(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer,
int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
template<typename OpType>
Y SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
return 0;
}
template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer) {
//
}
template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::exec(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer,
int *dimension,
int dimensionLength) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES);
}
}

View File

@ -114,6 +114,17 @@ namespace functions {
nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed");
}
template<typename X, typename Z>
void TransformAny<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
}
template<typename X, typename Z>
template <typename OpType>
void TransformAny<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES);
}
}

View File

@ -120,6 +120,17 @@ namespace functions {
nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed");
}
template<typename X, typename Z>
void TransformBool<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template<typename X, typename Z>
template <typename OpType>
void TransformBool<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
}
}

View File

@ -142,6 +142,17 @@ namespace functions {
nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed");
}
template<typename X, typename Z>
void TransformFloat<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template<typename X, typename Z>
template <typename OpType>
void TransformFloat<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
}

View File

@ -118,6 +118,17 @@ namespace functions {
nd4j::DebugHelper::checkErrorCode(stream, "transformSame(...) failed");
}
template<typename X>
void TransformSame<X>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template<typename X>
template <typename OpType>
void TransformSame<X>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES);
}
}

View File

@ -119,6 +119,17 @@ namespace functions {
nd4j::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed");
}
template<typename X>
void TransformStrict<X>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template<typename X>
template <typename OpType>
void TransformStrict<X>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);
}
}

View File

@ -209,15 +209,6 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
}
};
_CUDA_H Nd4jLong TypeCast::estimateQuantizedSize(Nd4jLong rawSize) {
if (rawSize <= 0)
throw std::runtime_error("Input size for quantization can't be <= 0");
// 2 fp32 values for max/min, and rawSize number of BYTES
return 8 + rawSize;
}
template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
template void TypeCast::convertFromThreshold<float16>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
template void TypeCast::convertFromThreshold<double>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);

View File

@ -69,7 +69,14 @@ namespace nd4j {
template <typename T>
static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize);
FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) {
if (rawSize <= 0)
throw std::runtime_error("Input size for quantization can't be <= 0");
// 2 fp32 values for max/min, and rawSize number of BYTES
return 8 + rawSize;
}
template <typename T>
static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz);

View File

@ -85,7 +85,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG);
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList;
}

View File

@ -86,7 +86,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG);
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList;
}

View File

@ -107,7 +107,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG);
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList;
}

View File

@ -114,7 +114,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG);
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList;
}

View File

@ -112,7 +112,8 @@ namespace nd4j {
COPY_SHAPE(input, epsShape);
COPY_SHAPE(bias, gradShape);
return SHAPELIST(epsShape, gradShape);
return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape));
}
}
}

View File

@ -75,7 +75,7 @@ namespace nd4j {
DECLARE_TYPES(non_max_suppression) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS});
->setAllowedOutputTypes({ALL_INDICES});
}
}

View File

@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) {
Nd4jLong *dLdbcShapeInfo = nullptr;
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo));
}

View File

@ -101,7 +101,7 @@ namespace ops {
Nd4jLong *out;
COPY_SHAPE(in, out);
return SHAPELIST(out);
return SHAPELIST(CONSTANT(out));
}
}

View File

@ -87,8 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES);
}
/*

View File

@ -89,7 +89,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES);
}
/*

View File

@ -119,11 +119,9 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp
void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) {
BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void col2im_, (nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW), LIBND4J_TYPES);
}
}
}

View File

@ -2445,71 +2445,52 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
}
}

View File

@ -81,10 +81,8 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
}
}
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES);
}

View File

@ -76,7 +76,7 @@ namespace nd4j {
double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES);
}
}
}

Some files were not shown because too many files have changed in this diff Show More