Merge remote-tracking branch 'fork/master'
commit 37e053ad90
@@ -16,6 +16,7 @@

package org.datavec.spark.transform.client;


import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
@@ -51,11 +51,13 @@
<artifactId>datavec-spark-inference-model</artifactId>
<version>${datavec.version}</version>
</dependency>

<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark_2.11</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
@@ -67,61 +69,73 @@
<artifactId>akka-cluster_2.11</artifactId>
<version>${akka.version}</version>
</dependency>

<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>${jodatime.version}</version>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>

<dependency>
<groupId>org.hibernate</groupId>
<artifactId>hibernate-validator</artifactId>
<version>${hibernate.version}</version>
</dependency>

<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>

<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>

<dependency>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
<version>${snakeyaml.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
@@ -137,39 +151,44 @@
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
<version>${jodah.typetools.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${play.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${play.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId>
<version>${play.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${play.version}</version>
</dependency>


<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
@@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest {
public static void before() throws Exception {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());

// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
@@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest {
}
}
});

server.runMain(new String[] {"-dp", "9050"});
}
@@ -16,6 +16,7 @@

package org.datavec.spark.transform;


import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
@@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
// Only one time

Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest {
}
}
});

server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
}
@@ -16,6 +16,7 @@

package org.datavec.spark.transform;


import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
@@ -16,6 +16,7 @@

package org.datavec.spark.transform;


import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
@@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
import com.mashape.unirest.request.HttpRequestWithBody;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
@@ -51,6 +51,7 @@ public class NearestNeighborsClient {

static {
// Only one time

Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
@@ -89,7 +90,7 @@ public class NearestNeighborsClient {
NearestNeighborRequest request = new NearestNeighborRequest();
request.setInputIndex(index);
request.setK(k);
HttpRequestWithBody req = Unirest.post(url + "/knn");
val req = Unirest.post(url + "/knn");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(request);
addAuthHeader(req);
@@ -112,7 +113,7 @@ public class NearestNeighborsClient {
Base64NDArrayBody base64NDArrayBody =
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();

HttpRequestWithBody req = Unirest.post(url + "/knnnew");
val req = Unirest.post(url + "/knnnew");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(base64NDArrayBody);
addAuthHeader(req);
@@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import jdk.nashorn.internal.objects.annotations.Property;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.compress.compressors.gzip.GzipUtils;
@@ -17,7 +17,7 @@
package org.deeplearning4j.nn.adapters;

import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -16,6 +16,7 @@

package org.deeplearning4j.nn.api;

import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
@@ -24,6 +24,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.exception.DL4JException;
@@ -25,6 +25,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
@@ -0,0 +1,110 @@
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <packaging>jar</packaging>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-remote</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-json-server</artifactId>
    <version>1.0.0-SNAPSHOT</version>
    <name>deeplearning4j-json-server</name>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>${junit.version}</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
            <scope>provided</scope>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-json-client</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-json-server</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>


    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <activation>
                <activeByDefault>true</activeByDefault>
            </activation>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>

        <profile>
            <id>test-nd4j-cuda-10.1</id>
            <activation>
                <activeByDefault>false</activeByDefault>
            </activation>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-10.1</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>
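Note: the two profiles above select the ND4J backend used for this module's tests, with test-nd4j-native active by default. As a hedged example (assuming a standard Maven invocation and a local CUDA 10.1 setup), the CUDA profile can be selected with:

    mvn test -P test-nd4j-cuda-10.1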
@@ -0,0 +1,205 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.parallelism.ParallelInference;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
import org.nd4j.remote.serving.SameDiffServlet;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author astoyakin
|
||||
*/
|
||||
@Slf4j
|
||||
@NoArgsConstructor
|
||||
public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||
|
||||
protected ParallelInference parallelInference;
|
||||
protected Model model;
|
||||
protected boolean parallelEnabled = true;
|
||||
|
||||
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.parallelInference = parallelInference;
|
||||
this.model = null;
|
||||
this.parallelEnabled = true;
|
||||
}
|
||||
|
||||
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.model = model;
|
||||
this.parallelInference = null;
|
||||
this.parallelEnabled = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||
String processorReturned = "";
|
||||
String path = request.getPathInfo();
|
||||
if (path.equals(SERVING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
if (validateRequest(request,response)) {
|
||||
val stream = request.getInputStream();
|
||||
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
||||
char[] charBuffer = new char[128];
|
||||
int bytesRead = -1;
|
||||
val buffer = new StringBuilder();
|
||||
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
|
||||
buffer.append(charBuffer, 0, bytesRead);
|
||||
}
|
||||
val requestString = buffer.toString();
|
||||
|
||||
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||
|
||||
O result = null;
|
||||
if (parallelEnabled) {
|
||||
// process result
|
||||
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
}
|
||||
else {
|
||||
synchronized(this) {
|
||||
if (model instanceof ComputationGraph)
|
||||
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
else if (model instanceof MultiLayerNetwork) {
|
||||
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
|
||||
"Input data for MultilayerNetwork is invalid!");
|
||||
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
||||
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
||||
}
|
||||
}
|
||||
}
|
||||
processorReturned = serializer.serialize(result);
|
||||
}
|
||||
} else {
|
||||
// we return error otherwise
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
try {
|
||||
val out = response.getWriter();
|
||||
out.write(processorReturned);
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder that creates a servlet to serve models
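*
* A minimal, hedged usage sketch (it mirrors how JsonModelServer.start() in this PR wires the servlet; the pi, serializer, deserializer and inferenceAdapter variables are assumed to exist):
* <pre>{@code
* // build a servlet backed by an existing ParallelInference instance
* DL4jServlet<I, O> servlet = new DL4jServlet.Builder<I, O>(pi)
*         .parallelEnabled(true)
*         .serializer(serializer)
*         .deserializer(deserializer)
*         .inferenceAdapter(inferenceAdapter)
*         .build();
* }</pre>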
|
||||
*
|
||||
* @param <I> type of Input class
|
||||
* @param <O> type of Output class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public static class Builder<I,O> {
|
||||
|
||||
private ParallelInference pi;
|
||||
private Model model;
|
||||
|
||||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
private int port;
|
||||
private boolean parallelEnabled = true;
|
||||
|
||||
public Builder(@NonNull ParallelInference pi) {
|
||||
this.pi = pi;
|
||||
}
|
||||
|
||||
public Builder(@NonNull Model model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method specifies the JSON serializer (required)
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the JSON deserializer
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the port
|
||||
*
|
||||
* @param port
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> port(int port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method enables or disables parallel inference
|
||||
*
|
||||
* @param parallelEnabled
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> parallelEnabled(boolean parallelEnabled) {
|
||||
this.parallelEnabled = parallelEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DL4jServlet<I,O> build() {
|
||||
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
|
||||
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
@@ -0,0 +1,392 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.parallelism.ParallelInference;
|
||||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.adapters.InputAdapter;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.remote.SameDiffJsonModelServer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models
|
||||
*
|
||||
* Server URL will be http://0.0.0.0:{port}/v1/serving
|
||||
* Server only accepts POST requests
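*
* A minimal, hedged usage sketch (the House/PredictedPrice classes and their JSON serde come from this PR's tests and are shown purely for illustration; "net" is assumed to be an initialized MultiLayerNetwork):
* <pre>{@code
* JsonModelServer<House, PredictedPrice> server = new JsonModelServer.Builder<House, PredictedPrice>(net)
*         .inferenceAdapter(new HouseToPredictedPriceAdapter())
*         .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
*         .inputDeserializer(new House.HouseDeserializer())
*         .port(9090)
*         .build();
* server.start();
* // clients POST JSON to http://localhost:9090/v1/serving
* server.stop();
* }</pre>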
|
||||
*
|
||||
* @param <I> type of the input class, i.e. String
|
||||
* @param <O> type of the output class, i.e. Sentiment
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||
|
||||
// all serving goes through ParallelInference
|
||||
protected ParallelInference parallelInference;
|
||||
|
||||
|
||||
protected ModelAdapter<O> modelAdapter;
|
||||
|
||||
// actual models
|
||||
protected ComputationGraph cgModel;
|
||||
protected MultiLayerNetwork mlnModel;
|
||||
|
||||
// service stuff
|
||||
protected InferenceMode inferenceMode;
|
||||
protected int numWorkers;
|
||||
|
||||
protected boolean enabledParallel = true;
|
||||
|
||||
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
||||
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.cgModel = cgModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.mlnModel = mlnModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.parallelInference = pi;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method stops the server
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
@Override
|
||||
public void stop() throws Exception {
|
||||
if (parallelInference != null)
|
||||
parallelInference.shutdown();
|
||||
super.stop();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method starts the server
|
||||
* @throws Exception
|
||||
*/
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
// if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case
|
||||
if (sdModel != null) {
|
||||
super.start();
|
||||
return;
|
||||
}
|
||||
Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either a MultiLayerNetwork or a ComputationGraph to be defined");
|
||||
|
||||
val model = cgModel != null ? (Model) cgModel : (Model) mlnModel;
|
||||
// PI construction is optional, since we can have it defined
|
||||
if (enabledParallel) {
|
||||
if (parallelInference == null) {
|
||||
Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");
|
||||
|
||||
parallelInference = new ParallelInference.Builder(model)
|
||||
.inferenceMode(inferenceMode)
|
||||
.workers(numWorkers)
|
||||
.loadBalanceMode(LoadBalanceMode.FIFO)
|
||||
.batchLimit(16)
|
||||
.queueLimit(128)
|
||||
.build();
|
||||
}
|
||||
servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
|
||||
.parallelEnabled(true)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
else {
|
||||
servingServlet = new DL4jServlet.Builder<I, O>(model)
|
||||
.parallelEnabled(false)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
start(port, servingServlet);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder that creates a JsonModelServer able to serve different types of models
|
||||
*
|
||||
* @param <I> type of Input class
|
||||
* @param <O> type of Output class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public static class Builder<I,O> {
|
||||
|
||||
private SameDiff sdModel;
|
||||
private ComputationGraph cgModel;
|
||||
private MultiLayerNetwork mlnModel;
|
||||
private ParallelInference pi;
|
||||
|
||||
private String[] orderedInputNodes;
|
||||
private String[] orderedOutputNodes;
|
||||
|
||||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
|
||||
private InputAdapter<I> inputAdapter;
|
||||
private OutputAdapter<O> outputAdapter;
|
||||
|
||||
private int port;
|
||||
|
||||
private boolean parallelMode = true;
|
||||
|
||||
// these fields actually require defaults
|
||||
private InferenceMode inferenceMode = InferenceMode.BATCHED;
|
||||
private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||
|
||||
public Builder(@NonNull SameDiff sdModel) {
|
||||
this.sdModel = sdModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull MultiLayerNetwork mlnModel) {
|
||||
this.mlnModel = mlnModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull ComputationGraph cgModel) {
|
||||
this.cgModel = cgModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull ParallelInference pi) {
|
||||
this.pi = pi;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines the InferenceAdapter implementation, which is used to convert an object of the Input type into a set of INDArray(s), and to convert the resulting INDArray(s) into an object of the Output type
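*
* A hedged sketch of such an adapter (it mirrors the anonymous adapters used in this PR's tests; the float[]/Integer types are purely illustrative):
* <pre>{@code
* InferenceAdapter<float[], Integer> adapter = new InferenceAdapter<float[], Integer>() {
*     public MultiDataSet apply(float[] input) {
*         // wrap the raw input into the features of a MultiDataSet
*         return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
*     }
*
*     public Integer apply(INDArray... nnOutput) {
*         // reduce the first network output to a single label index
*         return nnOutput[0].argMax().getInt(0);
*     }
* };
* }</pre>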
|
||||
* @param inferenceAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify InputAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, but requires an OutputAdapter<O> to be defined as well
|
||||
* @param inputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
|
||||
this.inputAdapter = inputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the OutputAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, but requires an InputAdapter<I> to be defined as well
|
||||
* @param outputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
|
||||
this.outputAdapter = outputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the output serializer
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the input deserializer
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the inference mode used in parallel mode. See {@link InferenceMode} for more details
|
||||
*
|
||||
* @param inferenceMode
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inferenceMode(@NonNull InferenceMode inferenceMode) {
|
||||
this.inferenceMode = inferenceMode;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the number of worker threads for ParallelInference
|
||||
*
|
||||
* @param numWorkers
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> numWorkers(int numWorkers) {
|
||||
this.numWorkers = numWorkers;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
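*
* A hedged example for a SameDiff graph (the "input"/"total" names and the serde helpers follow this PR's tests; "sd" is assumed to be a SameDiff instance):
* <pre>{@code
* val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
*         .inferenceAdapter(new HouseToPredictedPriceAdapter())
*         .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
*         .inputDeserializer(new House.HouseDeserializer())
*         .orderedInputNodes("input")
*         .orderedOutputNodes("total")
*         .port(9091)
*         .build();
* }</pre>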
|
||||
*
|
||||
* PLEASE NOTE: this argument is only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(String... args) {
|
||||
orderedInputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
|
||||
*
|
||||
* PLEASE NOTE: this argument is only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(@NonNull List<String> args) {
|
||||
orderedInputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify output nodes
|
||||
*
|
||||
* PLEASE NOTE: this argument is only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(String... args) {
|
||||
Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify output nodes
|
||||
*
|
||||
* PLEASE NOTE: this argument is only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(@NonNull List<String> args) {
|
||||
Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the HTTP port
|
||||
*
|
||||
* PLEASE NOTE: the port must be free and within the valid TCP/IP port range
|
||||
* @param port
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> port(int port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method switches ParallelInference usage on or off
|
||||
* @param enable true to use ParallelInference, false to use the ComputationGraph or
|
||||
* MultiLayerNetwork directly
|
||||
*
|
||||
* PLEASE NOTE: this doesn't apply to SameDiff models
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> parallelMode(boolean enable) {
|
||||
this.parallelMode = enable;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonModelServer<I,O> build() {
|
||||
if (inferenceAdapter == null) {
|
||||
if (inputAdapter != null && outputAdapter != null) {
|
||||
inferenceAdapter = new InferenceAdapter<I, O>() {
|
||||
@Override
|
||||
public MultiDataSet apply(I input) {
|
||||
return inputAdapter.apply(input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public O apply(INDArray... outputs) {
|
||||
return outputAdapter.apply(outputs);
|
||||
}
|
||||
};
|
||||
} else
|
||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||
}
|
||||
|
||||
if (sdModel != null) {
|
||||
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
|
||||
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
} else if (cgModel != null)
|
||||
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (mlnModel != null)
|
||||
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (pi != null)
|
||||
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port);
|
||||
else
|
||||
throw new IllegalStateException("No models were defined for JsonModelServer");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@@ -0,0 +1,749 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||
import org.deeplearning4j.remote.helpers.House;
|
||||
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
|
||||
import org.deeplearning4j.remote.helpers.PredictedPrice;
|
||||
import org.junit.After;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.remote.clients.JsonRemoteInference;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class JsonModelServerTest {
|
||||
private static final MultiLayerNetwork model;
|
||||
private final int PORT = 18080;
|
||||
|
||||
static {
|
||||
val conf = new NeuralNetConfiguration.Builder()
|
||||
.seed(119)
|
||||
.updater(new Adam(0.119f))
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build())
|
||||
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build())
|
||||
.build();
|
||||
|
||||
model = new MultiLayerNetwork(conf);
|
||||
model.init();
|
||||
}
|
||||
|
||||
@After
|
||||
public void pause() throws Exception {
|
||||
// TODO: the same port was used in the previous test and is not immediately available again; a better solution than sleeping might be needed.
|
||||
TimeUnit.SECONDS.sleep(2);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStartStopParallel() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
try {
|
||||
serverDL.start();
|
||||
serverSD.start();
|
||||
|
||||
val clientDL = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
PredictedPrice price = clientDL.predict(house);
|
||||
long timeStart = System.currentTimeMillis();
|
||||
price = clientDL.predict(house);
|
||||
long timeStop = System.currentTimeMillis();
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
val clientSD = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
|
||||
.build();
|
||||
|
||||
PredictedPrice price2 = clientSD.predict(house);
|
||||
timeStart = System.currentTimeMillis();
|
||||
price = clientSD.predict(house);
|
||||
timeStop = System.currentTimeMillis();
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 3.0, price.getPrice(), 1e-5);
|
||||
|
||||
}
|
||||
finally {
|
||||
serverSD.stop();
|
||||
serverDL.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStartStopSequential() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
|
||||
serverDL.start();
|
||||
serverDL.stop();
|
||||
|
||||
serverSD.start();
|
||||
serverSD.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForSD() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
price = client.predict(house);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
|
||||
}
|
||||
finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForDLSynchronized() throws Exception {
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(INPLACE)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build();
|
||||
House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house1);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
PredictedPrice price1 = client.predict(house1);
|
||||
PredictedPrice price2 = client.predict(house2);
|
||||
PredictedPrice price3 = client.predict(house3);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForDL() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
price = client.predict(house);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_1() {
|
||||
String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
|
||||
val deserializer = new House.HouseDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(2, result.getDistrict());
|
||||
assertEquals(100, result.getArea());
|
||||
assertEquals(2, result.getBathrooms());
|
||||
assertEquals(3, result.getBedrooms());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_2() {
|
||||
String request = "{\"price\":1}";
|
||||
val deserializer = new PredictedPrice.PredictedPriceDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(1.0, result.getPrice(), 1e-4);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void negativeServingTest_1() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(null)
|
||||
.port(18080)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test //(expected = NullPointerException.class)
|
||||
public void negativeServingTest_2() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IOException.class)
|
||||
public void negativeServingTest_3() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void asyncServingTest() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
}
|
||||
finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void negativeAsyncTest() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(InferenceMode.BATCHED)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
// Fake deserializer to test failure
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
try {
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
} catch (ExecutionException e) {
|
||||
assertTrue(e.getMessage().contains("Deserialization failed"));
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSameDiffMnist() throws Exception {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
|
||||
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(sd)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try{
|
||||
server.start();
|
||||
for( int i=0; i<10; i++ ){
|
||||
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28);
|
||||
INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax");
|
||||
float[] fArr = f.toFloatVector();
|
||||
int out = client.predict(fArr);
|
||||
assertEquals(exp.argMax().getInt(0), out);
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMlnMnist() throws Exception {
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.list()
|
||||
.layer(new DenseLayer.Builder().nIn(784).nOut(10).build())
|
||||
.layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28);
|
||||
INDArray exp = net.output(f);
|
||||
float[] fArr = f.toFloatVector();
|
||||
int out = client.predict(fArr);
|
||||
assertEquals(exp.argMax().getInt(0), out);
|
||||
}
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompGraph() throws Exception {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.graphBuilder()
|
||||
.addInputs("input1", "input2")
|
||||
.addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
|
||||
.addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
|
||||
.addVertex("merge", new MergeVertex(), "L1", "L2")
|
||||
.addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge")
|
||||
.setOutputs("out")
|
||||
.build();
|
||||
|
||||
ComputationGraph net = new ComputationGraph(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
//client.predict(new float[]{0.0f, 1.0f, 2.0f});
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompGraph_1() throws Exception {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.updater(new Sgd(0.01))
|
||||
.graphBuilder()
|
||||
.addInputs("input")
|
||||
.addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input")
|
||||
.addLayer("out1", new OutputLayer.Builder()
|
||||
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||
.nIn(4).nOut(3).build(), "L1")
|
||||
.addLayer("out2", new OutputLayer.Builder()
|
||||
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.nIn(4).nOut(2).build(), "L1")
|
||||
.setOutputs("out1","out2")
|
||||
.build();
|
||||
|
||||
final ComputationGraph net = new ComputationGraph(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("input")
|
||||
.orderedOutputNodes("out")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
assertNotNull(result);
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
private static class FloatSerde implements JsonSerializer<float[]>, JsonDeserializer<float[]>{
|
||||
private final ObjectMapper om = new ObjectMapper();
|
||||
|
||||
@Override
|
||||
public float[] deserialize(String json) {
|
||||
try {
|
||||
return om.readValue(json, FloatHolder.class).getFloats();
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(float[] o) {
|
||||
try{
|
||||
return om.writeValueAsString(new FloatHolder(o));
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
//Use a float holder so Jackson does ser/de properly (the result is just "{}" otherwise)
|
||||
@AllArgsConstructor @NoArgsConstructor @Data
|
||||
private static class FloatHolder {
|
||||
private float[] floats;
|
||||
}
|
||||
}
|
||||
|
||||
private static class IntSerde implements JsonSerializer<Integer>, JsonDeserializer<Integer> {
|
||||
private final ObjectMapper om = new ObjectMapper();
|
||||
|
||||
@Override
|
||||
public Integer deserialize(String json) {
|
||||
try {
|
||||
return om.readValue(json, Integer.class);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(Integer o) {
|
||||
try{
|
||||
return om.writeValueAsString(o);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
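For reference, a minimal self-contained sketch of the FloatHolder trick used in the serde classes above: wrapping the float[] in a Lombok @Data holder gives Jackson visible properties to read and write, instead of the empty "{}" the comment warns about. It assumes only jackson-databind and Lombok on the classpath; the class and method names here are illustrative and not part of the test sources.

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

public class FloatHolderRoundTrip {

    // Same shape as the FloatHolder above: @Data generates the getters/setters Jackson needs
    @Data @AllArgsConstructor @NoArgsConstructor
    static class FloatHolder {
        private float[] floats;
    }

    public static void main(String[] args) throws Exception {
        ObjectMapper om = new ObjectMapper();
        String json = om.writeValueAsString(new FloatHolder(new float[]{1.0f, 2.0f, 3.0f}));
        // expected form: {"floats":[1.0,2.0,3.0]}
        float[] back = om.readValue(json, FloatHolder.class).getFloats();
        System.out.println(json + " -> " + back.length + " floats");
    }
}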
|
@ -0,0 +1,133 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.impl.client.HttpClientBuilder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ServletTest {
|
||||
|
||||
private JsonModelServer server;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
server = new JsonModelServer.Builder<String,String>(sd)
|
||||
.port(8080)
|
||||
.inferenceAdapter(new InferenceAdapter<String, String>() {
|
||||
@Override
|
||||
public MultiDataSet apply(String input) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String apply(INDArray... nnOutput) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.outputSerializer(new JsonSerializer<String>() {
|
||||
@Override
|
||||
public String serialize(String o) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.inputDeserializer(new JsonDeserializer<String>() {
|
||||
@Override
|
||||
public String deserialize(String json) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("input")
|
||||
.orderedOutputNodes("output")
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
//server.join();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getEndpoints() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "application/json");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(200, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypeGet() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypePost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void postForServing() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(500, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundPost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving/some");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(404, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundGet() throws Exception {
|
||||
val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
|
||||
requestGet.setHeader("Content-type", "application/json");
|
||||
|
||||
val responseGet = HttpClientBuilder.create().build().execute( requestGet );
|
||||
assertEquals(404, responseGet.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.*;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class House {
|
||||
private int district;
|
||||
private int bedrooms;
|
||||
private int bathrooms;
|
||||
private int area;
|
||||
|
||||
|
||||
public static class HouseSerializer implements JsonSerializer<House> {
|
||||
@Override
|
||||
public String serialize(@NonNull House o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class HouseDeserializer implements JsonDeserializer<House> {
|
||||
@Override
|
||||
public House deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, House.class);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
@Slf4j
|
||||
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
|
||||
|
||||
@Override
|
||||
public MultiDataSet apply(@NonNull House input) {
|
||||
// build a feature vector of shape [1, 4] and fill every element with the district value
|
||||
return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PredictedPrice apply(INDArray... nnOutput) {
|
||||
return new PredictedPrice(nnOutput[0].getFloat(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PredictedPrice {
|
||||
private float price;
|
||||
|
||||
public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
|
||||
@Override
|
||||
public String serialize(@NonNull PredictedPrice o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
|
||||
@Override
|
||||
public PredictedPrice deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, PredictedPrice.class);
|
||||
}
|
||||
}
|
||||
}
|
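A minimal sketch of how the House, PredictedPrice, and HouseToPredictedPriceAdapter helpers above plug into the JsonModelServer / JsonRemoteInference builders exercised in the tests earlier in this change. The model variable (sd), the port, and the node names are placeholders, and the snippet is assumed to live in a test method declared throws Exception with the same imports as JsonModelServerTest; it is illustrative only, not an additional test.

// Sketch only: assumes a trained model `sd` (e.g. a SameDiff instance) and a free port.
val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
        .inferenceAdapter(new HouseToPredictedPriceAdapter())
        .inputDeserializer(new House.HouseDeserializer())
        .outputSerializer(new PredictedPrice.PredictedPriceSerializer())
        .orderedInputNodes("input")
        .orderedOutputNodes("output")
        .port(9090)
        .build();

val client = JsonRemoteInference.<House, PredictedPrice>builder()
        .endpointAddress("http://localhost:9090/v1/serving")
        .inputSerializer(new House.HouseSerializer())
        .outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
        .build();

server.start();
try {
    House house = House.builder().district(2).bedrooms(3).bathrooms(2).area(100).build();
    PredictedPrice price = client.predict(house);
    System.out.println("Predicted price: " + price.getPrice());
} finally {
    server.stop();
}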
|
@ -0,0 +1,48 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.eclipse.jetty" level="WARN" />
|
||||
<logger name="org.apache.catalina.core" level="WARN" />
|
||||
<logger name="org.springframework" level="WARN" />
|
||||
<logger name="org.nd4j" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="INFO" />
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
|
@ -0,0 +1,30 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<modules>
|
||||
<module>deeplearning4j-json-server</module>
|
||||
</modules>
|
||||
|
||||
<parent>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>deeplearning4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<name>deeplearning4j-remote</name>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.1</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -21,7 +21,6 @@ import lombok.*;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.val;
|
|||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
|
|
|
@ -144,6 +144,7 @@
|
|||
<module>dl4j-perf</module>
|
||||
<module>dl4j-integration-tests</module>
|
||||
<module>deeplearning4j-common</module>
|
||||
<module>deeplearning4j-remote</module>
|
||||
</modules>
|
||||
|
||||
<dependencyManagement>
|
||||
|
|
|
@ -163,9 +163,9 @@ if(CUDA_BLAS)
|
|||
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
if (APPLE)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
|
@ -173,24 +173,24 @@ if(CUDA_BLAS)
|
|||
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
if (APPLE)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60)
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
else()
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 )
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
@ -205,34 +205,34 @@ if(CUDA_BLAS)
|
|||
message("CUDA 10 Debug build")
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
if (APPLE)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
elseif()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
if (APPLE)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
elseif()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62)
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
else()
|
||||
if ("${COMPUTE}" STREQUAL "all")
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53)
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
@ -249,7 +249,7 @@ if(CUDA_BLAS)
|
|||
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
||||
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
|
||||
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h)
|
||||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
||||
|
||||
if (NOT BUILD_TESTS)
|
||||
|
|
|
@ -344,7 +344,7 @@ bool NDArray::isS() const {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::isR() const {
|
||||
auto xType = ArrayOptions::dataType(this->_shapeInfo);
|
||||
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8;
|
||||
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1769,6 +1769,17 @@ ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
|
|||
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
|
||||
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
|
||||
|
||||
typedef nd4j::LaunchContext OpaqueLaunchContext;
|
||||
|
||||
ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext();
|
||||
ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
|
||||
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
|
||||
|
||||
}
|
||||
|
||||
#endif //NATIVEOPERATIONS_NATIVEOPS_H
|
||||
|
|
|
@ -2985,6 +2985,38 @@ const char* runFullBenchmarkSuit(bool printOut) {
|
|||
return chars;
|
||||
}
|
||||
|
||||
nd4j::LaunchContext* defaultLaunchContext() {
|
||||
return LaunchContext::defaultContext();
|
||||
}
|
||||
|
||||
Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||
|
|
|
@ -356,7 +356,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
|
|||
auto stream = getContext()->getCudaStream();
|
||||
|
||||
prepareSpecialUse({&target}, {this});
|
||||
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
|
||||
registerSpecialUse({&target}, {this});
|
||||
}
|
||||
|
||||
|
@ -375,7 +375,7 @@ void NDArray::tile(NDArray& target) const {
|
|||
auto stream = getContext()->getCudaStream();
|
||||
|
||||
prepareSpecialUse({&target}, {this});
|
||||
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES);
|
||||
registerSpecialUse({&target}, {this});
|
||||
}
|
||||
|
||||
|
@ -434,7 +434,7 @@ void NDArray::repeat(int dimension, NDArray& target) const {
|
|||
|
||||
NDArray::prepareSpecialUse({&target}, {this});
|
||||
auto stream = getContext()->getCudaStream();
|
||||
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES);
|
||||
NDArray::registerSpecialUse({&target}, {this});
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,14 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
|
||||
return shape::getIndexOffset(index, shapeInfo, length);
|
||||
}
|
||||
|
||||
static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
|
||||
return shape::length(shapeInfo);
|
||||
}
|
||||
|
||||
template <typename T, typename Lambda> static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
|
||||
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
|
||||
template <typename T, typename Lambda> static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda);
|
||||
|
@ -86,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
|
|||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
auto zLength = __length(zShapeInfo);
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
@ -95,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
|
|||
z[e * zEws] = lambda(x[e * xEws]);
|
||||
} else {
|
||||
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
|
||||
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
|
||||
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
|
||||
|
||||
z[zOffset] = lambda(x[xOffset]);
|
||||
}
|
||||
|
@ -115,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
|
|||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
auto zLength = __length(zShapeInfo);
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
@ -124,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
|
|||
z[e * zEws] = lambda(e, x[e * xEws]);
|
||||
} else {
|
||||
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
|
||||
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
|
||||
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
|
||||
|
||||
z[zOffset] = lambda(e, x[xOffset]);
|
||||
}
|
||||
|
@ -147,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
|
|||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
auto zLength = __length(zShapeInfo);
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
@ -156,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
|
|||
z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
|
||||
} else {
|
||||
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
|
||||
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
|
||||
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
|
||||
|
||||
z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
|
||||
}
|
||||
|
@ -180,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
|
|||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
auto zLength = __length(zShapeInfo);
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
@ -189,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
|
|||
z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
|
||||
} else {
|
||||
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
|
||||
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
|
||||
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
|
||||
|
||||
z[zOffset] = lambda(x[xOffset], y[yOffset]);
|
||||
}
|
||||
|
@ -216,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
|
|||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = shape::order(zShapeInfo);
|
||||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
auto zLength = __length(zShapeInfo);
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
@ -225,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
|
|||
z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
|
||||
} else {
|
||||
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
|
||||
auto wOffset = shape::getIndexOffset(e, wShapeInfo, zLength);
|
||||
auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength);
|
||||
auto wOffset = __getIndexOffset(e, wShapeInfo, zLength);
|
||||
auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
|
||||
auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
|
||||
auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
|
||||
|
||||
z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <helpers/threshold.h>
|
||||
#include <ops/specials_cuda.h>
|
||||
#include <helpers/DebugHelper.h>
|
||||
#include <AffinityManager.h>
|
||||
|
||||
#include <exceptions/datatype_exception.h>
|
||||
#include <helpers/CudaLaunchHelper.h>
|
||||
|
@ -1691,11 +1692,7 @@ void setOmpMinThreads(int threads) {
|
|||
}
|
||||
|
||||
int getDevice() {
|
||||
int curDevice = -1;
|
||||
|
||||
cudaGetDevice(&curDevice);
|
||||
|
||||
return curDevice;
|
||||
return nd4j::AffinityManager::currentDeviceId();
|
||||
}
|
||||
|
||||
void setElementThreshold(int num) {
|
||||
|
@ -2391,8 +2388,8 @@ void sortByValue(Nd4jPointer *extraPointers,
|
|||
|
||||
auto xLength = shape::length(xShapeInfo);
|
||||
auto xEWS = shape::elementWiseStride(xShapeInfo);
|
||||
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||
|
||||
|
||||
// check if xLength is a power of 2, and use bitonic sort, if that's the case
|
||||
|
@ -2406,7 +2403,7 @@ void sortByValue(Nd4jPointer *extraPointers,
|
|||
|
||||
for (int k = 2; k <= xLength; k = 2*k) {
|
||||
for (int j = k >> 1; j > 0; j = j >> 1) {
|
||||
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -2430,7 +2427,7 @@ void sortByValue(Nd4jPointer *extraPointers,
|
|||
int rev = 0;
|
||||
do{
|
||||
int half = n >> 1;
|
||||
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
n>>=1;
|
||||
rev = 1;
|
||||
} while(n > 1);
|
||||
|
@ -3342,6 +3339,7 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
|
|||
nd4j::graph::Context* createGraphContext(int nodeId) {
|
||||
return new nd4j::graph::Context(nodeId);
|
||||
}
|
||||
|
||||
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
|
||||
return &ptr->randomGenerator();
|
||||
}
|
||||
|
@ -3460,3 +3458,35 @@ const char* runFullBenchmarkSuit(bool printOut) {
|
|||
Nd4jLong getCachedMemory(int deviceId) {
|
||||
return nd4j::ConstantHelper::getInstance()->getCachedAmount(deviceId);
|
||||
}
|
||||
|
||||
nd4j::LaunchContext* defaultLaunchContext() {
|
||||
return LaunchContext::defaultContext();
|
||||
}
|
||||
|
||||
Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
|
||||
return lc->getScalarPointer();
|
||||
}
|
||||
|
||||
Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
|
||||
return lc->getReductionPointer();
|
||||
}
|
||||
|
||||
Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
|
||||
return lc->getAllocationPointer();
|
||||
}
|
||||
|
||||
Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
|
||||
return lc->getCudaStream();
|
||||
}
|
||||
|
||||
Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
|
||||
return lc->getCudaSpecialStream();
|
||||
}
|
||||
|
||||
Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
|
||||
return lc->getCublasHandle();
|
||||
}
|
||||
|
||||
Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
|
||||
return lc->getCusolverHandle();
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_AFFINITYMANAGER_H
|
||||
#define LIBND4J_AFFINITYMANAGER_H
|
||||
|
||||
#include <dll.h>
|
||||
#include <pointercast.h>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT AffinityManager {
|
||||
private:
|
||||
static std::atomic<int> _lastDevice;
|
||||
static int _numberOfDevices;
|
||||
static std::mutex _currentMutex;
|
||||
static std::mutex _numberMutex;
|
||||
|
||||
public:
|
||||
static int currentNativeDeviceId();
|
||||
static int currentDeviceId();
|
||||
static int numberOfDevices();
|
||||
static void setCurrentDevice(int deviceId);
|
||||
static void setCurrentNativeDevice(int deviceId);
|
||||
};
|
||||
}
|
||||
|
||||
#endif //LIBND4J_AFFINITYMANAGER_H
|
|
@ -0,0 +1,58 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_CONTEXTBUFFERS_H
|
||||
#define LIBND4J_CONTEXTBUFFERS_H
|
||||
|
||||
#include <dll.h>
|
||||
#include <pointercast.h>
|
||||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT ContextBuffers {
|
||||
private:
|
||||
void* _reductionPointer;
|
||||
void* _scalarPointer;
|
||||
void* _allocationPointer;
|
||||
bool _allocated = true;
|
||||
|
||||
int _deviceId = -1;
|
||||
|
||||
void initialize();
|
||||
public:
|
||||
ContextBuffers();
|
||||
ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false);
|
||||
~ContextBuffers();
|
||||
|
||||
void* reductionBuffer();
|
||||
void* scalarBuffer();
|
||||
void* allocationBuffer();
|
||||
|
||||
void setReductionBuffer(void* pointer);
|
||||
void setScalarBuffer(void* pointer);
|
||||
void setAllocationBuffer(void* pointer);
|
||||
|
||||
void triggerOwnership(bool isOwner);
|
||||
|
||||
int deviceId();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#endif //LIBND4J_CONTEXTBUFFERS_H
|
|
@ -35,6 +35,8 @@
|
|||
#include <op_boilerplate.h>
|
||||
#include <memory/Workspace.h>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <execution/ContextBuffers.h>
|
||||
|
||||
|
||||
|
||||
|
@ -44,17 +46,16 @@ class ND4J_EXPORT LaunchContext {
|
|||
|
||||
private:
|
||||
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
|
||||
static std::mutex _mutex;
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
|
||||
#ifndef __JAVACPP_HACK__
|
||||
|
||||
void* _reductionPointer;
|
||||
void* _scalarPointer;
|
||||
int* _allocationPointer;
|
||||
cudaStream_t *_cudaStream = nullptr;
|
||||
cudaStream_t *_cudaSpecialStream = nullptr;
|
||||
void *_cublasHandle = nullptr;
|
||||
cudaStream_t* _cudaStream = nullptr;
|
||||
cudaStream_t* _cudaSpecialStream = nullptr;
|
||||
void* _cublasHandle = nullptr;
|
||||
void* _cusolverHandle = nullptr;
|
||||
|
||||
#endif // JCPP
|
||||
|
||||
|
@ -62,31 +63,27 @@ class ND4J_EXPORT LaunchContext {
|
|||
#endif // CUDA
|
||||
nd4j::memory::Workspace* _workspace = nullptr;
|
||||
int _deviceID = 0;
|
||||
|
||||
public:
|
||||
#ifdef __CUDABLAS__
|
||||
|
||||
#ifndef __JAVACPP_HACK__
|
||||
LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr);
|
||||
|
||||
FORCEINLINE void* getReductionPointer () const {return _reductionPointer;};
|
||||
void* getReductionPointer () const;
|
||||
void* getScalarPointer() const;
|
||||
int* getAllocationPointer() const;
|
||||
void* getCublasHandle() const;
|
||||
void* getCusolverHandle() const;
|
||||
cudaStream_t* getCudaStream() const;
|
||||
cudaStream_t* getCudaSpecialStream() const;
|
||||
|
||||
FORCEINLINE void* getScalarPointer() const {return _scalarPointer;};
|
||||
|
||||
FORCEINLINE int* getAllocationPointer() const {return _allocationPointer;};
|
||||
|
||||
FORCEINLINE void* getCublasHandle() const {return _cublasHandle;};
|
||||
FORCEINLINE cudaStream_t* getCudaStream() const {return _cudaStream;};
|
||||
FORCEINLINE cudaStream_t* getCudaSpecialStream() const {return _cudaSpecialStream;};
|
||||
|
||||
FORCEINLINE void setReductionPointer (void* reductionPointer) {_reductionPointer = reductionPointer;};
|
||||
|
||||
FORCEINLINE void setScalarPointer(void* scalarPointer) {_scalarPointer = scalarPointer;};
|
||||
|
||||
FORCEINLINE void setAllocationPointer(int* allocationPointer) {_allocationPointer = allocationPointer;};
|
||||
|
||||
FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
|
||||
FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
|
||||
FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; };
|
||||
void setReductionPointer (void* reductionPointer);
|
||||
void setScalarPointer(void* scalarPointer);
|
||||
void setAllocationPointer(int* allocationPointer);
|
||||
void setCudaStream(cudaStream_t* cudaStream);
|
||||
void setCudaSpecialStream(cudaStream_t* cudaStream);
|
||||
void setCublasHandle(void *handle);
|
||||
|
||||
|
||||
#endif // JCPP
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <execution/AffinityManager.h>
|
||||
|
||||
namespace nd4j {
|
||||
int AffinityManager::currentDeviceId() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int AffinityManager::currentNativeDeviceId() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int AffinityManager::numberOfDevices() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void AffinityManager::setCurrentDevice(int deviceId) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
void AffinityManager::setCurrentNativeDevice(int deviceId) {
|
||||
// no-op
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
#include <execution/ContextBuffers.h>
|
||||
#include <execution/AffinityManager.h>
|
||||
|
||||
namespace nd4j {
|
||||
ContextBuffers::ContextBuffers() {
|
||||
_deviceId = AffinityManager::currentDeviceId();
|
||||
}
|
||||
|
||||
ContextBuffers::~ContextBuffers() {
|
||||
// no-op
|
||||
}
|
||||
|
||||
ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
|
||||
_reductionPointer = rPointer;
|
||||
_scalarPointer = sPointer;
|
||||
_allocationPointer = aPointer;
|
||||
_allocated = isOwner;
|
||||
}
|
||||
|
||||
void ContextBuffers::initialize() {
|
||||
// no-op
|
||||
}
|
||||
|
||||
void* ContextBuffers::reductionBuffer() {
|
||||
return _reductionPointer;
|
||||
}
|
||||
|
||||
void* ContextBuffers::scalarBuffer() {
|
||||
return _scalarPointer;
|
||||
}
|
||||
|
||||
void* ContextBuffers::allocationBuffer() {
|
||||
return _allocationPointer;
|
||||
}
|
||||
|
||||
void ContextBuffers::setReductionBuffer(void* pointer) {
|
||||
_reductionPointer = pointer;
|
||||
}
|
||||
|
||||
void ContextBuffers::setScalarBuffer(void* pointer) {
|
||||
_scalarPointer = pointer;
|
||||
}
|
||||
|
||||
void ContextBuffers::setAllocationBuffer(void* pointer) {
|
||||
_allocationPointer = pointer;
|
||||
}
|
||||
|
||||
void ContextBuffers::triggerOwnership(bool isOwner) {
|
||||
_allocated = isOwner;
|
||||
}
|
||||
|
||||
int ContextBuffers::deviceId() {
|
||||
return _deviceId;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver119 on 30.11.17.
|
||||
//
|
||||
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <logger.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <thread>
|
||||
|
||||
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
LaunchContext::~LaunchContext() {
|
||||
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
LaunchContext::LaunchContext() {
|
||||
// default constructor, just to make clang/ranlib happy
|
||||
_workspace = nullptr;
|
||||
_deviceID = 0;
|
||||
}
|
||||
|
||||
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
||||
|
||||
}
|
||||
|
||||
LaunchContext* LaunchContext::defaultContext() {
|
||||
// TODO: we need it to be device-aware, but only once we add NUMA support for cpu
|
||||
if (LaunchContext::_contexts.empty()) {
|
||||
LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
|
||||
}
|
||||
|
||||
// return context for current device
|
||||
return LaunchContext::_contexts[0].get();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <logger.h>
#include <execution/AffinityManager.h>
#include <exceptions/cuda_exception.h>

thread_local int globalThreadToDevice = -1;

namespace nd4j {
    std::mutex AffinityManager::_currentMutex;
    std::mutex AffinityManager::_numberMutex;
    int AffinityManager::_numberOfDevices = -1;

    int AffinityManager::currentDeviceId() {
        // if there's no affinity set - set it now
        if (globalThreadToDevice < 0) {

            // this block must be thread-local
            _currentMutex.lock();

            globalThreadToDevice = _lastDevice++;

            // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise
            if (globalThreadToDevice >= numberOfDevices()) {
                globalThreadToDevice = 0;
                _lastDevice = numberOfDevices() > 1 ? 1 : 0;
            }

            _currentMutex.unlock();

            setCurrentDevice(globalThreadToDevice);
        }

        // if we already know affinity - just return it
        if (globalThreadToDevice >= 0)
            return globalThreadToDevice;

        int dev = 0;
        auto res = cudaGetDevice(&dev);

        if (res != 0)
            throw cuda_exception::build("cudaGetDevice failed", res);

        return dev;
    }

    int AffinityManager::currentNativeDeviceId() {
        int dev = 0;
        auto res = cudaGetDevice(&dev);

        if (res != 0)
            throw cuda_exception::build("cudaGetDevice failed", res);

        return dev;
    }

    int AffinityManager::numberOfDevices() {
        _numberMutex.lock();
        // we want to cache number of devices
        if (_numberOfDevices <= 0) {
            int dev = 0;
            auto res = cudaGetDeviceCount(&dev);

            if (res != 0)
                throw cuda_exception::build("cudaGetDeviceCount failed", res);

            _numberOfDevices = dev;
        }
        _numberMutex.unlock();

        return _numberOfDevices;
    }

    void AffinityManager::setCurrentNativeDevice(int deviceId) {
        auto res = cudaSetDevice(deviceId);
    }

    void AffinityManager::setCurrentDevice(int deviceId) {
        auto res = cudaSetDevice(deviceId);
        if (res != 0)
            throw cuda_exception::build("cudaSetDevice failed", res);

        // update thread-device affinity
        globalThreadToDevice = deviceId;

        // TODO: update context buffers?
    }

    std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
}
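A minimal standalone sketch of the round-robin thread-to-device assignment performed by currentDeviceId() above (illustration only, not part of this patch; pickDevice and its names are hypothetical, and it uses a modulo where the patch resets _lastDevice explicitly):

    #include <atomic>

    static std::atomic<int> lastDevice{0};
    thread_local int myDevice = -1;

    // the first time a thread asks for a device, it gets the next one in line
    int pickDevice(int numDevices) {
        if (myDevice < 0)
            myDevice = lastDevice++ % numDevices;
        return myDevice;
    }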
@ -0,0 +1,116 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <execution/ContextBuffers.h>
#include <logger.h>
#include <AffinityManager.h>

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_device_runtime_api.h>

namespace nd4j {
    ContextBuffers::ContextBuffers() {
        nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId());
        _deviceId = AffinityManager::currentDeviceId();
    }

    ContextBuffers::~ContextBuffers() {
        if (_allocated) {
            nd4j_printf("Releasing ContextBuffers\n","");

            if (_allocationPointer != nullptr)
                cudaFree(_allocationPointer);

            if (_scalarPointer != nullptr)
                cudaFree(_scalarPointer);

            if (_reductionPointer != nullptr)
                cudaFree(_reductionPointer);
        }
    }

    ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) {
        _reductionPointer = rPointer;
        _scalarPointer = sPointer;
        _allocationPointer = aPointer;
        _allocated = isOwner;
    }

    void ContextBuffers::initialize() {
        nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId());

        auto res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_reductionPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
        if (res != 0)
            throw std::runtime_error("_scalarPointer allocation failed");

        res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
        if (res != 0)
            throw std::runtime_error("_allocationPointer allocation failed");

        _allocated = true;
    }

    void* ContextBuffers::reductionBuffer() {
        if (_reductionPointer == nullptr)
            initialize();

        return _reductionPointer;
    }

    void* ContextBuffers::scalarBuffer() {
        if (_scalarPointer == nullptr)
            initialize();

        return _scalarPointer;
    }

    void* ContextBuffers::allocationBuffer() {
        if (_allocationPointer == nullptr)
            initialize();

        return _allocationPointer;
    }

    void ContextBuffers::setReductionBuffer(void* pointer) {
        _reductionPointer = pointer;
    }

    void ContextBuffers::setScalarBuffer(void* pointer) {
        _scalarPointer = pointer;
    }

    void ContextBuffers::setAllocationBuffer(void* pointer) {
        _allocationPointer = pointer;
    }

    void ContextBuffers::triggerOwnership(bool isOwner) {
        _allocated = isOwner;
    }

    int ContextBuffers::deviceId() {
        return _deviceId;
    }
}
@ -0,0 +1,182 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 30.11.17.
//

#include <execution/LaunchContext.h>
#include <logger.h>
#include <exceptions/cuda_exception.h>
#include <helpers/cublasHelper.h>
#include <thread>
#include <execution/AffinityManager.h>

thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();

namespace nd4j {

    std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
    std::mutex LaunchContext::_mutex;

    ////////////////////////////////////////////////////////////////////////
    LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {

        _cudaStream = cudaStream;
        _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
        //_reductionPointer = reductionPointer;
        //_scalarPointer = scalarPointer;
        //_allocationPointer = allocationPointer;
        _workspace = nullptr;
        _isAllocated = false;
    }

    LaunchContext::~LaunchContext() {
        if (_isAllocated) {
            cudaStreamSynchronize(*_cudaStream);
            cudaStreamSynchronize(*_cudaSpecialStream);

            cudaStreamDestroy(*_cudaStream);
            cudaStreamDestroy(*_cudaSpecialStream);

            delete _cudaStream;
            delete _cudaSpecialStream;
        }
    }

    ////////////////////////////////////////////////////////////////////////
    LaunchContext::LaunchContext() {
        // default constructor, just to make clang/ranlib happy
        _workspace = nullptr;
        _deviceID = 0;

        _isAllocated = true;
        _cudaStream = new cudaStream_t();
        _cudaSpecialStream = new cudaStream_t();
        if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
            throw std::runtime_error("Failed to allocate memory for new CUDA stream");

        cudaError_t err = cudaStreamCreate(_cudaStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);

        err = cudaStreamCreate(_cudaSpecialStream);
        if (err != 0)
            throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);

        _cublasHandle = CublasHelper::getInstance()->handle();

        _cusolverHandle = CublasHelper::getInstance()->solver();

        auto res = cudaStreamSynchronize(*_cudaStream);
        if (res != 0)
            throw cuda_exception::build("Initial sync failed", res);
    }

    LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
        _isAllocated = false;
        _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
        //_reductionPointer = reductionPointer;
        //_scalarPointer = scalarPointer;
        //_allocationPointer = reinterpret_cast<int *>(allocationPointer);
    }

    LaunchContext* LaunchContext::defaultContext() {
        /**
         * This method returns LaunchContext, that has multiple entities within:
         * 1) temporary buffers. they must be per-thread
         * 2) CUDA stream. it must be either per-thread or per-device
         * 3) cuBLAS handle. it must be per-device
         */
        auto deviceId = AffinityManager::currentDeviceId();

        // we need this block synchronous, to avoid double initialization etc
        _mutex.lock();
        if (LaunchContext::_contexts.empty()) {
            // create one context per device
            auto numDevices = AffinityManager::numberOfDevices();

            _contexts.resize(numDevices);
            for (int e = 0; e < numDevices; e++) {
                AffinityManager::setCurrentDevice(e);

                LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
            }

            // don't forget to restore device back again
            AffinityManager::setCurrentDevice(deviceId);
        }
        _mutex.unlock();

        // return context for current device
        return LaunchContext::_contexts[deviceId].get();
    }
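A minimal usage sketch for the per-device context above (illustration only, not part of this patch): a caller grabs the context for its current device lazily and reads the stream and handles from it, while the buffer getters below are served from the thread_local contextBuffers instance.

    // runs against whatever device the calling thread is bound to
    auto ctx = nd4j::LaunchContext::defaultContext();
    cudaStream_t* stream = ctx->getCudaStream();        // per-device stream
    void* blasHandle = ctx->getCublasHandle();          // per-device cuBLAS handle
    void* reduction = ctx->getReductionPointer();       // per-thread buffer, allocated on first use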

    void* LaunchContext::getReductionPointer () const {
        return contextBuffers.reductionBuffer();
    };

    void* LaunchContext::getScalarPointer() const {
        return contextBuffers.scalarBuffer();
    };

    int* LaunchContext::getAllocationPointer() const {
        return reinterpret_cast<int*>(contextBuffers.allocationBuffer());
    };

    void* LaunchContext::getCublasHandle() const {
        return _cublasHandle;
    };

    void* LaunchContext::getCusolverHandle() const {
        return _cusolverHandle;
    };

    cudaStream_t* LaunchContext::getCudaStream() const {
        return _cudaStream;
    };

    cudaStream_t* LaunchContext::getCudaSpecialStream() const {
        return _cudaSpecialStream;
    };


    void LaunchContext::setReductionPointer (void* reductionPointer) {
        contextBuffers.setReductionBuffer(reductionPointer);
    };

    void LaunchContext::setScalarPointer(void* scalarPointer) {
        contextBuffers.setScalarBuffer(scalarPointer);
    };

    void LaunchContext::setAllocationPointer(int* allocationPointer) {
        contextBuffers.setAllocationBuffer(allocationPointer);
    };

    void LaunchContext::setCudaStream(cudaStream_t* cudaStream) {
        _cudaStream = cudaStream;
    };

    void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) {
        _cudaSpecialStream = cudaStream;
    };

    void LaunchContext::setCublasHandle(void *handle) {
        _cublasHandle = handle;
    };
}
|
@ -1,130 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver119 on 30.11.17.
|
||||
//
|
||||
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <logger.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <helpers/cublasHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
|
||||
|
||||
_cudaStream = cudaStream;
|
||||
_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
|
||||
_reductionPointer = reductionPointer;
|
||||
_scalarPointer = scalarPointer;
|
||||
_allocationPointer = allocationPointer;
|
||||
_workspace = nullptr;
|
||||
_isAllocated = false;
|
||||
}
|
||||
#endif
|
||||
|
||||
LaunchContext::~LaunchContext() {
|
||||
#ifdef __CUDABLAS__
|
||||
if (_isAllocated) {
|
||||
cudaStreamSynchronize(*_cudaStream);
|
||||
cudaStreamSynchronize(*_cudaSpecialStream);
|
||||
|
||||
cudaStreamDestroy(*_cudaStream);
|
||||
cudaStreamDestroy(*_cudaSpecialStream);
|
||||
|
||||
delete _cudaStream;
|
||||
delete _cudaSpecialStream;
|
||||
|
||||
cudaFree(_reductionPointer);
|
||||
cudaFree(_allocationPointer);
|
||||
cudaFree(_scalarPointer);
|
||||
|
||||
cublas::destroyHandle(_cublasHandle);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
LaunchContext::LaunchContext() {
|
||||
// default constructor, just to make clang/ranlib happy
|
||||
_workspace = nullptr;
|
||||
_deviceID = 0;
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
_isAllocated = true;
|
||||
_cudaStream = new cudaStream_t();
|
||||
_cudaSpecialStream = new cudaStream_t();
|
||||
if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
|
||||
throw std::runtime_error("Failed to allocate memory for new CUDA stream");
|
||||
|
||||
cudaError_t err = cudaStreamCreate(_cudaStream);
|
||||
if (err != 0)
|
||||
throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);
|
||||
|
||||
err = cudaStreamCreate(_cudaSpecialStream);
|
||||
if (err != 0)
|
||||
throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);
|
||||
|
||||
_cublasHandle = cublas::handle();
|
||||
|
||||
auto res = cudaStreamSynchronize(*_cudaStream);
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("Initial sync failed", res);
|
||||
|
||||
res = cudaMalloc(reinterpret_cast<void**>(&_reductionPointer), 1024 * 1024 * 8);
|
||||
if (res != 0)
|
||||
throw std::runtime_error("_reductionPointer allocation failed");
|
||||
|
||||
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 8);
|
||||
if (res != 0)
|
||||
throw std::runtime_error("_scalarPointer allocation failed");
|
||||
|
||||
res = cudaMalloc(reinterpret_cast<void**>(&_allocationPointer), 1024 * 1024 * 8);
|
||||
if (res != 0)
|
||||
throw std::runtime_error("_allocationPointer allocation failed");
|
||||
#else
|
||||
//
|
||||
#endif
|
||||
}
|
||||
|
||||
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
|
||||
#ifdef __CUDABLAS__
|
||||
_isAllocated = false;
|
||||
_cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
|
||||
_cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
|
||||
_reductionPointer = reductionPointer;
|
||||
_scalarPointer = scalarPointer;
|
||||
_allocationPointer = reinterpret_cast<int *>(allocationPointer);
|
||||
#else
|
||||
// no-op
|
||||
#endif
|
||||
}
|
||||
|
||||
LaunchContext* LaunchContext::defaultContext() {
|
||||
// TODO: we need it to be device-aware
|
||||
if (LaunchContext::_contexts.empty()) {
|
||||
LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
|
||||
}
|
||||
return LaunchContext::_contexts[0].get();
|
||||
}
|
||||
|
||||
}
|
|
@ -21,6 +21,7 @@
#ifndef __CUDABLAS__

#include <ConstantHelper.h>
#include <execution/AffinityManager.h>
#include <types/types.h>
#include <loops/type_conversions.h>
#include <type_boilerplate.h>

@ -59,11 +60,11 @@ namespace nd4j {
    }

    int ConstantHelper::getCurrentDevice() {
        return 0L;
        return AffinityManager::currentDeviceId();
    }

    int ConstantHelper::getNumberOfDevices() {
        return 1;
        return AffinityManager::numberOfDevices();
    }

    ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
@ -21,6 +21,7 @@
|
|||
#include "../MmulHelper.h"
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/BlasHelper.h>
|
||||
#include <exceptions/datatype_exception.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
|
@ -148,6 +149,11 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
// MXK x KxN = MxN
|
||||
NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
|
||||
if (A->dataType() != B->dataType())
|
||||
throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType());
|
||||
|
||||
if (C != nullptr && A->dataType() != C->dataType())
|
||||
throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType());
|
||||
|
||||
if(A->rankOf() != 2)
|
||||
throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !");
|
||||
|
@ -212,7 +218,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
|||
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
|
||||
}
|
||||
else {
|
||||
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES);
|
||||
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
if(pC != C) {
|
||||
|
@ -230,6 +237,11 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
|||
////////////////////////////////////////////////////////////////////////////
|
||||
// MXN x N = M
|
||||
NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) {
|
||||
if (X->dataType() != A->dataType())
|
||||
throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType());
|
||||
|
||||
if (Y != nullptr && X->dataType() != Y->dataType())
|
||||
throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType());
|
||||
|
||||
int xLenDim, yLenDim(0);
|
||||
|
||||
|
@ -279,7 +291,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
|||
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy);
|
||||
}
|
||||
else {
|
||||
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES);
|
||||
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
if(pA != A)
|
||||
|
@ -291,6 +304,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
|||
////////////////////////////////////////////////////////////////////////////
|
||||
// (X * Y) = Z[0]
|
||||
NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) {
|
||||
if (X->dataType() != Y->dataType())
|
||||
throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType());
|
||||
|
||||
if (Z != nullptr && X->dataType() != Z->dataType())
|
||||
throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType());
|
||||
|
||||
int xLenDim(0), yLenDim(0);
|
||||
|
||||
|
@ -316,13 +334,14 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
|
|||
const auto yType = Y->dataType();
|
||||
const auto zType = Z->dataType();
|
||||
|
||||
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES);
|
||||
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
|
||||
return Z;
|
||||
}
|
||||
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||
|
||||
}
|
||||
|
|
|
@ -21,13 +21,41 @@
|
|||
#include "../cublasHelper.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace cublas {
|
||||
void* handle() {
|
||||
static void* handle_() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void destroyHandle(void* handle) {
|
||||
//
|
||||
static void destroyHandle_(void* handle) {
|
||||
|
||||
}
|
||||
|
||||
CublasHelper::CublasHelper() {
|
||||
|
||||
}
|
||||
|
||||
CublasHelper::~CublasHelper() {
|
||||
|
||||
}
|
||||
|
||||
CublasHelper* CublasHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new nd4j::CublasHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
}
|
||||
|
||||
void* CublasHelper::handle() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void* CublasHelper::solver() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void* CublasHelper::handle(int deviceId) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0;
|
||||
}
|
|
@ -21,12 +21,28 @@
|
|||
#ifndef DEV_TESTS_CUBLASHELPER_H
|
||||
#define DEV_TESTS_CUBLASHELPER_H
|
||||
|
||||
namespace nd4j {
|
||||
namespace cublas {
|
||||
void* handle();
|
||||
#include <dll.h>
|
||||
#include <pointercast.h>
|
||||
#include <vector>
|
||||
|
||||
void destroyHandle(void* handle);
|
||||
}
|
||||
namespace nd4j {
|
||||
class CublasHelper {
|
||||
private:
|
||||
static CublasHelper *_INSTANCE;
|
||||
|
||||
std::vector<void*> _cache;
|
||||
std::vector<void*> _solvers;
|
||||
|
||||
CublasHelper();
|
||||
~CublasHelper();
|
||||
public:
|
||||
static CublasHelper* getInstance();
|
||||
|
||||
void* solver();
|
||||
|
||||
void* handle();
|
||||
void* handle(int deviceId);
|
||||
};
|
||||
}
|
||||
|
||||
#endif //DEV_TESTS_CUBLASHELPER_H
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <logger.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <execution/AffinityManager.h>
|
||||
|
||||
#define CONSTANT_LIMIT 49152
|
||||
|
||||
|
@ -43,23 +44,11 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
int ConstantHelper::getCurrentDevice() {
|
||||
int dev = 0;
|
||||
auto res = cudaGetDevice(&dev);
|
||||
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("cudaGetDevice failed", res);
|
||||
|
||||
return dev;
|
||||
return AffinityManager::currentDeviceId();
|
||||
}
|
||||
|
||||
int ConstantHelper::getNumberOfDevices() {
|
||||
int dev = 0;
|
||||
auto res = cudaGetDeviceCount(&dev);
|
||||
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("cudaGetDeviceCount failed", res);
|
||||
|
||||
return dev;
|
||||
return AffinityManager::numberOfDevices();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
|||
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
|
||||
}
|
||||
|
||||
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
|
||||
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
|
||||
}
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
|
||||
|
@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
|||
threadsPerBlock.x = 512;
|
||||
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
|
||||
}
|
||||
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
|
||||
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
|
||||
}
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
|
||||
|
@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
|
|||
|
||||
NDArray::prepareSpecialUse({Z}, {X, Y});
|
||||
|
||||
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
|
||||
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
|
||||
|
||||
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
|
||||
|
@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
|
|||
return Z;
|
||||
}
|
||||
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
|
||||
}
|
|
@ -20,12 +20,15 @@
|
|||
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cusolverDn.h>
|
||||
#include "../cublasHelper.h"
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <helpers/logger.h>
|
||||
#include <execution/AffinityManager.h>
|
||||
|
||||
namespace nd4j {
|
||||
void* cublas::handle() {
|
||||
|
||||
static void* handle_() {
|
||||
auto _handle = new cublasHandle_t();
|
||||
auto status = cublasCreate_v2(_handle); // initialize CUBLAS context
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
|
@ -34,7 +37,16 @@ namespace nd4j {
|
|||
return reinterpret_cast<void *>(_handle);
|
||||
}
|
||||
|
||||
void cublas::destroyHandle(void* handle) {
|
||||
static void* solver_() {
|
||||
auto cusolverH = new cusolverDnHandle_t();
|
||||
auto status = cusolverDnCreate(cusolverH);
|
||||
if (status != CUSOLVER_STATUS_SUCCESS)
|
||||
throw cuda_exception::build("cuSolver handle creation failed !", status);
|
||||
|
||||
return cusolverH;
|
||||
}
|
||||
|
||||
static void destroyHandle_(void* handle) {
|
||||
auto ch = reinterpret_cast<cublasHandle_t *>(handle);
|
||||
auto status = cublasDestroy_v2(*ch);
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
|
@ -42,4 +54,57 @@ namespace nd4j {

        delete ch;
    }

    CublasHelper::CublasHelper() {
        auto numDevices = AffinityManager::numberOfDevices();
        auto currentDevice = AffinityManager::currentDeviceId();
        _cache.resize(numDevices);
        _solvers.resize(numDevices);
        for (int e = 0; e < numDevices; e++) {
            AffinityManager::setCurrentDevice(e);

            _cache[e] = handle_();
            _solvers[e] = solver_();
        }

        // don't forget to restore back original device
        AffinityManager::setCurrentDevice(currentDevice);
    }
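Illustration only (not part of this patch): with one handle cached per device, a caller retrieves the cuBLAS handle matching its current device without creating a new one each time.

    // handle() resolves AffinityManager::currentDeviceId() and indexes the per-device cache;
    // the cached entry was created by handle_() from a heap-allocated cublasHandle_t
    auto raw = nd4j::CublasHelper::getInstance()->handle();
    auto blas = reinterpret_cast<cublasHandle_t*>(raw);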

    CublasHelper::~CublasHelper() {
        auto numDevices = AffinityManager::numberOfDevices();

        for (int e = 0; e < numDevices; e++)
            destroyHandle_(_cache[e]);
    }

    CublasHelper* CublasHelper::getInstance() {
        if (!_INSTANCE)
            _INSTANCE = new nd4j::CublasHelper();

        return _INSTANCE;
    }

    void* CublasHelper::handle() {
        auto deviceId = AffinityManager::currentDeviceId();
        return handle(deviceId);
    }

    void* CublasHelper::solver() {
        auto deviceId = AffinityManager::currentDeviceId();
        if (deviceId < 0 || deviceId >= _solvers.size())
            throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);

        return _solvers[deviceId];
    }

    void* CublasHelper::handle(int deviceId) {
        if (deviceId < 0 || deviceId >= _cache.size())
            throw cuda_exception::build("requested deviceId doesn't look valid", deviceId);

        return _cache[deviceId];
    }


    nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0;
}
@ -276,23 +276,6 @@ namespace functions {
|
|||
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
|
||||
}
|
||||
|
||||
// FIXME: eventually we might want to get rid of that
|
||||
#ifndef __CLION_IDE__
|
||||
/*
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
*/
|
||||
#endif
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
|
||||
}
|
||||
|
|
|
@ -60,9 +60,18 @@ static __global__ void broadcastInverseSimple(
|
|||
functions::broadcast::Broadcast<X,Y,Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
|
||||
}
|
||||
|
||||
|
||||
namespace functions {
|
||||
namespace broadcast {
|
||||
|
||||
static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
|
||||
return shape::getIndexOffset(index, shapeInfo, length);
|
||||
}
|
||||
|
||||
static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
|
||||
return shape::length(shapeInfo);
|
||||
}
|
||||
|
||||
template<typename X, typename Y, typename Z>
|
||||
template <typename OpClass>
|
||||
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||
|
@ -120,9 +129,9 @@ namespace functions {
|
|||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
tadLength = shape::length(tadOnlyShapeInfo);
|
||||
tadLength = _length(tadOnlyShapeInfo);
|
||||
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
|
||||
numTads = shape::length(yShapeInfo) / tadLength;
|
||||
numTads = _length(yShapeInfo) / tadLength;
|
||||
xEWS = shape::elementWiseStride(xShapeInfo);
|
||||
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
|
||||
}
|
||||
|
@ -146,9 +155,9 @@ namespace functions {
|
|||
else {
|
||||
// it is expected that x and z tads and y array all have the same length
|
||||
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
|
||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
|
||||
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
|
||||
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
|
||||
auto xOffset = _getIndexOffset(i, xShapeInfo, tadLength);
|
||||
auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
|
||||
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
|
||||
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
|
||||
}
|
||||
}
|
||||
|
@ -186,9 +195,9 @@ namespace functions {
|
|||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
tadLength = shape::length(tadOnlyShapeInfo);
|
||||
tadLength = _length(tadOnlyShapeInfo);
|
||||
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
|
||||
numTads = shape::length(xShapeInfo) / tadLength;
|
||||
numTads = _length(xShapeInfo) / tadLength;
|
||||
yEWS = shape::elementWiseStride(yShapeInfo);
|
||||
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
|
||||
}
|
||||
|
@ -212,14 +221,15 @@ namespace functions {
|
|||
// it is expected that x and z tads and y array all have the same length
|
||||
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
|
||||
|
||||
auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
|
||||
auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
|
||||
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
|
||||
auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
|
||||
auto yOffset = _getIndexOffset(i, yShapeInfo, tadLength);
|
||||
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
|
||||
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#include <loops/broadcasting.h>
|
||||
#include <loops/legacy_ops.h>
|
||||
#include <types/types.h>
|
||||
#include <Environment.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <StringUtils.h>
|
||||
#include <specials_cuda.h>
|
||||
|
||||
namespace functions {
|
||||
namespace broadcast {
|
||||
template <typename X, typename Y, typename Z>
|
||||
void Broadcast<X, Y, Z>::execInverse(int opNum,
|
||||
void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
//
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void Broadcast<X, Y, Z>::exec(int opNum,
|
||||
void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* CPU execution
|
||||
* @param x the input
|
||||
* @param xShapeInfo the x shape information
|
||||
* @param y the y data
|
||||
* @param yShapeInfo the y shape information
|
||||
* @param result the result
|
||||
* @param resultShapeInfo the result shape information
|
||||
* @param dimension the dimension to broadcast along long
|
||||
* @param dimensionLength the length of the dimension buffer
|
||||
*/
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void Broadcast<X, Y, Z>::exec(void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
//
|
||||
}
|
||||
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void Broadcast<X, Y, Z>::execInverse(void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -224,6 +224,77 @@ namespace functions {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename X, typename Y>
|
||||
void BroadcastBool<X,Y>::exec(int opNum,
|
||||
void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
|
||||
template<typename X, typename Y>
|
||||
void BroadcastBool<X,Y>::execInverse(int opNum,
|
||||
void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
|
||||
template<typename X, typename Y>
|
||||
template<typename OpType>
|
||||
void BroadcastBool<X,Y>::exec(void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
|
||||
template<typename X, typename Y>
|
||||
template<typename OpType>
|
||||
void BroadcastBool<X,Y>::execInverse(void *x,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *y,
|
||||
Nd4jLong *yShapeInfo,
|
||||
void *result,
|
||||
Nd4jLong *resultShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo,
|
||||
Nd4jLong *tadOffset,
|
||||
Nd4jLong *tadShapeInfoZ,
|
||||
Nd4jLong *tadOffsetZ) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
}
|
||||
}
|
|
@ -361,6 +361,32 @@ namespace functions {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
Nd4jLong IndexReduce<T>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void IndexReduce<T>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
||||
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<typename OpType>
|
||||
Nd4jLong IndexReduce<T>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<typename OpType>
|
||||
_CUDA_H void IndexReduce<T>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,79 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "../pairwise_transform.h"

namespace functions {
    namespace pairwise_transforms {

        template <typename X, typename Y, typename Z>
        void PairWiseTransform<X, Y, Z>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams) {

        }

        template <typename X, typename Y, typename Z>
        void PairWiseTransform<X, Y, Z>::exec(const int opNum, void *x, Nd4jLong xStride, void *y, Nd4jLong yStride, void *z, Nd4jLong resultStride, void *extraParams, Nd4jLong len) {

        }

        template <typename X, typename Y, typename Z>
        template<typename OpType>
        void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong* xShapeInfo, void *vy, Nd4jLong* yShapeInfo, void *vresult, Nd4jLong* zShapeInfo, void *vextraParams) {

        }

        template <typename X, typename Y, typename Z>
        template<typename OpType>
        void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong xStride, void *vy, Nd4jLong yStride, void *vresult, Nd4jLong resultStride, void *vextraParams, const Nd4jLong len) {

        }
    }
}
@@ -111,6 +111,63 @@ void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_
        DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
    }

    template<typename X, typename Y>
    void PairWiseBoolTransform<X,Y>::exec(const int opNum, void *dx, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *result, Nd4jLong *resultShapeBuffer, void *extraParams) {

    }

    template<typename X, typename Y>
    void PairWiseBoolTransform<X,Y>::exec(const int opNum, void *dx, Nd4jLong xStride, void *y, Nd4jLong yStride, void *result, Nd4jLong resultStride, void *extraParams, Nd4jLong n) {

    }

    template<typename X, typename Y>
    template<typename OpType>
    void PairWiseBoolTransform<X,Y>::exec(void *vx, Nd4jLong* xShapeBuffer, void *vy, Nd4jLong* yShapeBuffer, void *vresult, Nd4jLong* resultShapeBuffer, void *vextraParams) {

    }

    template<typename X, typename Y>
    template<typename OpType>
    void PairWiseBoolTransform<X,Y>::exec(void *vx, Nd4jLong xStride, void *vy, Nd4jLong yStride, void *vresult, Nd4jLong resultStride, void *vextraParams, const Nd4jLong n) {

    }

    BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
}
}
@@ -442,6 +442,39 @@ namespace functions {
        DEBUG_KERNEL(stream, opNum);
    }

    template<typename T>
    template<typename OpClass>
    void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    template<typename T>
    template<typename OpClass>
    void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    template<typename T>
    template<typename OpClass>
    void RandomFunction<T>::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    template<typename T>
    void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    template<typename T>
    void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    template<typename T>
    void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {

    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}
}
@@ -0,0 +1,82 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#include <loops/reduce3.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <specials_cuda.h>

namespace functions {
    namespace reduce3 {

        template <typename X, typename Y>
        template<typename OpType>
        void Reduce3<X,Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) {

        }

        template <typename X, typename Y>
        void Reduce3<X,Y>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) {

        }

        template <typename X, typename Y>
        template<typename OpType>
        void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {

        }

        template <typename X, typename Y>
        template<typename OpType>
        void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        }

        template <typename X, typename Y>
        template<typename OpType>
        void Reduce3<X,Y>::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {

        }

        template <typename X, typename Y>
        void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {

        }

        template <typename X, typename Y>
        void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

        }

        template <typename X, typename Y>
        void Reduce3<X,Y>::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {

        }
    }
}
@@ -0,0 +1,32 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "loops/scalar.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <op_boilerplate.h>
#include <helpers/TAD.h>
#include <types/types.h>

namespace functions {
    namespace scalar {

    }
}
@@ -231,6 +231,41 @@ void ScalarBoolTransform<X,Y>::executeCudaAlongDimension(dim3& launchDims, cudaS
    }

    BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);

    template<typename X, typename Y>
    template <typename OpType>
    void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    }

    template<typename X, typename Y>
    void ScalarBoolTransform<X,Y>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    }

    template<typename X, typename Y>
    void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {

    }

    template<typename X, typename Y>
    void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {

    }

    template<typename X, typename Y>
    template<typename OpType>
    void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {

    }

    template<typename X, typename Y>
    template<typename OpType>
    void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {

    }
}
}
@@ -21,84 +21,6 @@

#include <ops/specials_cuda.h>

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    int tid = threadIdx.x + blockDim.x * blockIdx.x;
    int half = window >> 1;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0) {
        xLength = shape::length(xShapeInfo);
    }
    __syncthreads();

    //for (int i = 0; i < length; i += window)
    /*
        if window == 4;
        iterations will be: 0; 4; 8; 12; 16; 20
        if gridDim = 3;
        on first iteration we'll have: 0; 4; 8;
        on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
    */
    int firstPosition;
    int firstStep;
    int secondPosition;
    int secondStep;

    int WARP_SIZE = 32;
    int numWarps = (gridDim.x * blockDim.x) / 32;
    int warpId = tid / WARP_SIZE;
    int warpIdx = tid % WARP_SIZE;

    if (half >= 128) {
        firstPosition = blockIdx.x * window;
        firstStep = gridDim.x * window;

        secondPosition = threadIdx.x;
        secondStep = blockDim.x;
    } else if (half >= 32) {
        firstPosition = warpId * window;
        firstStep = numWarps * window;

        secondPosition = warpIdx;
        secondStep = WARP_SIZE;
    } else {
        firstPosition = tid * window;
        firstStep = blockDim.x * gridDim.x * window;

        secondPosition = 0;
        secondStep = 1;
    }

    for (int i = firstPosition; i < length; i += firstStep) {
        for (int j = secondPosition; j < half; j += secondStep) {
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i + j;
            if (it < length && ij < length) {
                int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
                int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);

                Y v0 = y[posIJ];
                Y v1 = y[posIT];

                if (!descending == (v0 > v1)) {
                    y[posIJ] = v1;
                    y[posIT] = v0;

                    X xtemp = x[posIJ];
                    x[posIJ] = x[posIT];
                    x[posIT] = xtemp;
                }
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {

@@ -264,11 +186,5 @@ __host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *str
    bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}

template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}

BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
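A side note on the launch convention used by the host wrappers above: the three components of the dim3 are passed straight through as the usual CUDA launch parameters (grid size, block size, dynamic shared memory in bytes), and the kernel runs asynchronously on the supplied stream. The snippet below is a minimal, self-contained sketch of that pattern; the kernel and buffer names are illustrative placeholders, not libnd4j code.

#include <cuda_runtime.h>

// Placeholder kernel: doubles every element of a device buffer.
__global__ void scaleKernel(float *data, int length) {
    int i = threadIdx.x + blockDim.x * blockIdx.x;
    if (i < length)
        data[i] *= 2.0f;
}

int main() {
    const int length = 1024;
    float *dData;
    cudaMalloc(&dData, length * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Same positional convention as the wrappers above:
    // x = number of blocks, y = threads per block, z = dynamic shared memory bytes (0 here).
    dim3 launchDims(4, 256, 0);
    scaleKernel<<<launchDims.x, launchDims.y, launchDims.z, stream>>>(dData, length);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    cudaFree(dData);
    return 0;
}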
@@ -21,60 +21,6 @@

#include <ops/specials_cuda.h>

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {

    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    unsigned int i, ixj; /* Sorting partners: i and ixj */
    i = threadIdx.x + blockDim.x * blockIdx.x;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0)
        xLength = shape::length(xShapeInfo);

    __syncthreads();

    if (i >= length)
        return;

    ixj = i^j;

    /* The threads with the lowest ids sort the array. */
    if ((ixj)>i) {
        int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
        int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);

        if ((i&k)==0) {
            /* Sort ascending */
            if (!descending == (y[posI]>y[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        } else if ((i&k)!=0) {
            /* Sort descending */
            if (!descending == (y[posI]<y[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>

@@ -189,13 +135,6 @@ __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream,
    bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
    bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}

BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
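For orientation, the `ixj = i ^ j` partner logic in the kernel shown above is the standard bitonic-network step: every index i is paired with i XOR j, and the pair is conditionally swapped so that each block ends up ascending or descending depending on the (i & k) test. A tiny single-threaded C++ reference of the same stage, written only to illustrate the index math (it is not libnd4j code), could look like this:

#include <cstddef>
#include <utility>

// One bitonic compare-exchange pass over `data` for a given (j, k) stage.
// Mirrors the per-index logic of the CUDA kernel above, but runs sequentially.
template <typename T>
void bitonicStep(T *data, std::size_t length, std::size_t j, std::size_t k, bool descending) {
    for (std::size_t i = 0; i < length; i++) {
        std::size_t ixj = i ^ j;                 // sorting partner of i
        if (ixj > i && ixj < length) {
            bool ascendingBlock = ((i & k) == 0);
            bool outOfOrder = ascendingBlock ? (data[i] > data[ixj])
                                             : (data[i] < data[ixj]);
            // For a descending overall sort the comparison is inverted.
            if (outOfOrder != descending)
                std::swap(data[i], data[ixj]);
        }
    }
}

Running the full set of (j, k) stages of the bitonic schedule over a power-of-two length sorts the array; the kernel above performs exactly one such stage per launch.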
@@ -62,9 +62,9 @@ namespace nd4j {
            }
        }
    }
-   BUILD_DOUBLE_TEMPLATE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer,
+   BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer,
                           Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
-                          Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES, LIBND4J_TYPES);
+                          Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES);

    template <typename T>
    void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength,

@@ -88,10 +88,10 @@ namespace nd4j {
        dim3 launchDims(256, 512, 8192);
        repeatKernelDouble<X,Y><<<launchDims.x, launchDims.y, launchDims.z, stream>>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets);
    }
-   BUILD_DOUBLE_TEMPLATE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
+   BUILD_SINGLE_TEMPLATE_TWICE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
                           Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
                           Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets,
-                          cudaStream_t stream), LIBND4J_TYPES, LIBND4J_TYPES);
+                          cudaStream_t stream), LIBND4J_TYPES);

}
@@ -21,6 +21,17 @@
#include <loops/special_kernels.h>

namespace nd4j {
    static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
        return shape::getIndexOffset(index, shapeInfo, length);
    }

    static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
        return shape::subArrayOffset(index, shapeInfoA, shapeInfoB);
    }

    static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
        return shape::length(shapeInfo);
    }

    ////////////////////////////////////////////////////////////////////////
    template<typename T>
@@ -34,31 +45,20 @@ namespace nd4j {
        //const auto resultLength = shape::length(outputShape);
        if (shape::order(outputShape) == 'c') {           //  ews == 1 always here
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
+               auto yOffset = _subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<T *>(outputBuffer) + i) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
            }
-           // for(Nd4jLong i=0; i<resultLen; ++i) {
-           //     auto yOffset = shape::subArrayOffset(newShapeInfo, _shapeInfo, i);
-           //     BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, i, this->_buffer, yOffset), LIBND4J_TYPES);
-           //
-           // }
        } else {
-           //
-           //auto inputLength = shape::lenght(inputShape);
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto xOffset = shape::getIndexOffset(i, outputShape, resultLength);
-               auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
-               *(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) +
-                       yOffset);
-               // BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, xOffset, this->_buffer, yOffset), LIBND4J_TYPES);
+               auto xOffset = _getIndexOffset(i, outputShape, resultLength);
+               auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+               *(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
            }
        }

    }

-   BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,
-                         (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength),
-                         LIBND4J_TYPES);
+   BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), LIBND4J_TYPES);

    template<typename T>
    void tileKernelH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, cudaStream_t *stream) {

@@ -77,29 +77,26 @@ namespace nd4j {

        if (ordering == 'c' && ews == 1) {          //  ews == 1 always here
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
-               *(reinterpret_cast<X *>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) +
-                       yOffset));
+               auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+               *(reinterpret_cast<X *>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
            }
        } else if (ordering == 'c' && ews > 1) {
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
-               *(reinterpret_cast<X *>(outputBuffer) + i * ews) = static_cast<X>(*(
-                       reinterpret_cast<Y const *>(inputBuffer) + yOffset));
+               auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+               *(reinterpret_cast<X *>(outputBuffer) + i * ews) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
            }
        } else {

            for (int i = tid; i < resultLength; i += totalThreads) {

-               auto xOffset = shape::getIndexOffset(i, outputShape, resultLength);
-               auto yOffset = shape::subArrayOffset(i, outputShape, inputShape);
-               *(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(
-                       reinterpret_cast<Y const *>(inputBuffer) + yOffset));
+               auto xOffset = _getIndexOffset(i, outputShape, resultLength);
+               auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+               *(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
            }
        }
    }

-   BUILD_DOUBLE_TEMPLATE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES, LIBND4J_TYPES);
+   BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES);

    template<typename X, typename Y>
    void tileKernelHH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) {

@@ -107,5 +104,5 @@ namespace nd4j {
        tileKernelDouble<X, Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews);
    }

-   BUILD_DOUBLE_TEMPLATE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES, LIBND4J_TYPES);
+   BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES);
}
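The tile kernels in these hunks all use the same grid-stride idiom: each thread starts at its global id and advances by the total number of launched threads, so any result length is covered regardless of grid size. A minimal standalone CUDA sketch of that traversal, included purely for illustration (it is not libnd4j code):

#include <cuda_runtime.h>

// Copies `length` floats from `src` to `dst` with a grid-stride loop,
// the same traversal pattern used by tileKernel/tileKernelDouble above.
__global__ void gridStrideCopy(const float *src, float *dst, long long length) {
    long long tid = blockIdx.x * blockDim.x + threadIdx.x;
    long long totalThreads = (long long) gridDim.x * blockDim.x;

    for (long long i = tid; i < length; i += totalThreads)
        dst[i] = src[i];
}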
@@ -413,6 +413,74 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
    DEBUG_KERNEL(stream, opNum);
}

template <typename X, typename Y>
Y SummaryStatsReduce<X,Y>::execScalar(int opNum, bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
    return 0;
}

template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::execScalar(int opNum, bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *resultShapeInfoBuffer) {

}

template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::exec(int opNum, bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength) {

}

template <typename X, typename Y>
template<typename OpType>
Y SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
    return 0;
}

template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *resultShapeInfoBuffer) {
    //
}

template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::exec(bool biasCorrected, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength) {

}

BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES);
}
}
@@ -114,6 +114,17 @@ namespace functions {
        nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed");
    }

    template<typename X, typename Z>
    void TransformAny<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {

    }

    template<typename X, typename Z>
    template <typename OpType>
    void TransformAny<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {

    }

    BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES);
}
}

@@ -120,6 +120,17 @@ namespace functions {
        nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed");
    }

    template<typename X, typename Z>
    void TransformBool<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    template<typename X, typename Z>
    template <typename OpType>
    void TransformBool<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
}
}

@@ -142,6 +142,17 @@ namespace functions {
        nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed");
    }

    template<typename X, typename Z>
    void TransformFloat<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    template<typename X, typename Z>
    template <typename OpType>
    void TransformFloat<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
}

@@ -118,6 +118,17 @@ namespace functions {
        nd4j::DebugHelper::checkErrorCode(stream, "transformSame(...) failed");
    }

    template<typename X>
    void TransformSame<X>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    template<typename X>
    template <typename OpType>
    void TransformSame<X>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES);
}
}

@@ -119,6 +119,17 @@ namespace functions {
        nd4j::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed");
    }

    template<typename X>
    void TransformStrict<X>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    template<typename X>
    template <typename OpType>
    void TransformStrict<X>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);
}
}
@@ -209,15 +209,6 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
        }
    };

-   _CUDA_H Nd4jLong TypeCast::estimateQuantizedSize(Nd4jLong rawSize) {
-       if (rawSize <= 0)
-           throw std::runtime_error("Input size for quantization can't be <= 0");
-
-       // 2 fp32 values for max/min, and rawSize number of BYTES
-       return 8 + rawSize;
-   }
-
    template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
    template void TypeCast::convertFromThreshold<float16>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
    template void TypeCast::convertFromThreshold<double>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);

@@ -69,7 +69,14 @@ namespace nd4j {
        template <typename T>
        static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);

-       static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize);
+       FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) {
+           if (rawSize <= 0)
+               throw std::runtime_error("Input size for quantization can't be <= 0");
+
+           // 2 fp32 values for max/min, and rawSize number of BYTES
+           return 8 + rawSize;
+       }

        template <typename T>
        static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz);
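To make the sizing rule concrete: the quantized buffer stores two fp32 values (the min and max) followed by one byte per raw element, so N input values need 8 + N bytes. A standalone arithmetic sketch that mirrors the function above, included only for illustration:

#include <stdexcept>
#include <cstdint>

// Mirrors the estimateQuantizedSize() rule shown above: 8 bytes for the fp32
// min/max pair plus one byte per raw element. Illustrative helper, not libnd4j code.
int64_t quantizedBufferBytes(int64_t rawSize) {
    if (rawSize <= 0)
        throw std::runtime_error("Input size for quantization can't be <= 0");
    return 8 + rawSize;               // e.g. 1024 elements -> 1032 bytes
}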
@@ -85,7 +85,7 @@ namespace nd4j {
            COPY_SHAPE(x, shapeE);
            COPY_SHAPE(y, shapeG);

-           auto shapeList = SHAPELIST(shapeE, shapeG);
+           auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));

            return shapeList;
        }

@@ -86,7 +86,7 @@ namespace nd4j {
            COPY_SHAPE(x, shapeE);
            COPY_SHAPE(y, shapeG);

-           auto shapeList = SHAPELIST(shapeE, shapeG);
+           auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));

            return shapeList;
        }

@@ -107,7 +107,7 @@ namespace nd4j {
            COPY_SHAPE(x, shapeE);
            COPY_SHAPE(y, shapeG);

-           auto shapeList = SHAPELIST(shapeE, shapeG);
+           auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));

            return shapeList;
        }

@@ -114,7 +114,7 @@ namespace nd4j {
            COPY_SHAPE(x, shapeE);
            COPY_SHAPE(y, shapeG);

-           auto shapeList = SHAPELIST(shapeE, shapeG);
+           auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));

            return shapeList;
        }

@@ -112,7 +112,8 @@ namespace nd4j {
            COPY_SHAPE(input, epsShape);
            COPY_SHAPE(bias, gradShape);

-           return SHAPELIST(epsShape, gradShape);
+           return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape));

        }
    }
}

@@ -75,7 +75,7 @@ namespace nd4j {
        DECLARE_TYPES(non_max_suppression) {
            getOpDescriptor()
                    ->setAllowedInputTypes(nd4j::DataType::ANY)
-                   ->setAllowedOutputTypes({ALL_INTS});
+                   ->setAllowedOutputTypes({ALL_INDICES});
        }

    }

@@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) {
    Nd4jLong *dLdbcShapeInfo = nullptr;
    COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);

-   return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
+   return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo));
}


@@ -101,7 +101,7 @@ namespace ops {
        Nd4jLong *out;
        COPY_SHAPE(in, out);

-       return SHAPELIST(out);
+       return SHAPELIST(CONSTANT(out));
    }

}
@@ -87,8 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr


void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {

-   BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES);
}

/*

@@ -89,7 +89,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA

void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {

-   BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES);
}

/*

@@ -119,11 +119,9 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp


void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void col2im_, (nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW), LIBND4J_TYPES);

}
}
}
@@ -2445,71 +2445,52 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(


void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
-   BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
-   BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
-   BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
-   BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
-   BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
-   BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
-   BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
-   BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}
void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
-   BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES);
}



void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}
void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
-   BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
+   BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
}

-   BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
-   BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
-   BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
-   BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
-   BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);

-   BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
-   BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);

}
}
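The selector macros touched in this hunk amount to a runtime switch over the array's data type that forwards to the matching template instantiation; narrowing the type list to FLOAT_TYPES simply shrinks that switch to the floating-point cases. The following macro-free C++ sketch illustrates only the general idea, with hypothetical names; it is not the actual BUILD_SINGLE_SELECTOR_TWICE expansion.

#include <stdexcept>

// Hypothetical stand-ins used only to illustrate data-type dispatch to a template.
enum class DataType { FLOAT32, DOUBLE, HALF, INT32 };

template <typename T>
void upsample2dImpl(const void *input, void *output, int factorH, int factorW) {
    // ... real per-type work would go here ...
}

void upsample2d(DataType dtype, const void *input, void *output, int factorH, int factorW) {
    switch (dtype) {
        case DataType::FLOAT32: upsample2dImpl<float>(input, output, factorH, factorW);  break;
        case DataType::DOUBLE:  upsample2dImpl<double>(input, output, factorH, factorW); break;
        // A half-precision case would also appear in a float-types-only dispatch.
        default:
            throw std::runtime_error("upsample2d: unsupported data type for float-only dispatch");
    }
}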
@@ -81,10 +81,8 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
        }
    }

-   BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
-
    void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
-       BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
+       BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES);
    }


@@ -76,7 +76,7 @@ namespace nd4j {
        double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
        double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);

-       BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);
+       BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES);
    }
}
}
Some files were not shown because too many files have changed in this diff.