diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml
index 064dd3ecd..04a1fc0f0 100644
--- a/arbiter/arbiter-core/pom.xml
+++ b/arbiter/arbiter-core/pom.xml
@@ -91,7 +91,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml
index 85afe7a6b..b163e2ae4 100644
--- a/arbiter/arbiter-deeplearning4j/pom.xml
+++ b/arbiter/arbiter-deeplearning4j/pom.xml
@@ -70,7 +70,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
index 9d9db6261..c64a06040 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
@@ -305,7 +305,7 @@ public class TestGraphLocalExecution {
@Test
public void testLocalExecutionEarlyStopping() throws Exception {
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
- .epochTerminationConditions(new MaxEpochsTerminationCondition(6))
+ .epochTerminationConditions(new MaxEpochsTerminationCondition(4))
.scoreCalculator(new ScoreProvider())
.modelSaver(new InMemoryModelSaver()).build();
Map commands = new HashMap<>();
@@ -348,7 +348,7 @@ public class TestGraphLocalExecution {
.dataProvider(dataProvider)
.scoreFunction(ScoreFunctions.testSetF1())
.modelSaver(new FileModelSaver(modelSavePath))
- .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
+ .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.build();
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java
index 1e652cdbe..4416dd8cf 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java
@@ -32,7 +32,7 @@ public class TestDataFactoryProviderMnist implements DataSetIteratorFactory {
private int terminationIter;
public TestDataFactoryProviderMnist(){
- this(16, 10);
+ this(16, 4);
}
@Override
diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml
index 5d14fa6a0..bdea61138 100644
--- a/arbiter/arbiter-server/pom.xml
+++ b/arbiter/arbiter-server/pom.xml
@@ -56,7 +56,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml
index 03a335ea8..2067a3fc7 100644
--- a/arbiter/arbiter-ui/pom.xml
+++ b/arbiter/arbiter-ui/pom.xml
@@ -37,7 +37,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/arbiter/pom.xml b/arbiter/pom.xml
index 5f660c646..364c6d904 100644
--- a/arbiter/pom.xml
+++ b/arbiter/pom.xml
@@ -151,7 +151,7 @@
${skipTestResourceEnforcement}
- test-nd4j-native,test-nd4j-cuda-10.1
+ test-nd4j-native,test-nd4j-cuda-10.2
false
@@ -333,11 +333,11 @@
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
org.nd4j
- nd4j-cuda-10.1
+ nd4j-cuda-10.2
${nd4j.version}
test
diff --git a/change-cuda-versions.sh b/change-cuda-versions.sh
index 8acda5b1a..21f17bb72 100755
--- a/change-cuda-versions.sh
+++ b/change-cuda-versions.sh
@@ -20,7 +20,7 @@
set -e
-VALID_VERSIONS=( 9.2 10.0 10.1 )
+VALID_VERSIONS=( 9.2 10.0 10.1 10.2 )
usage() {
echo "Usage: $(basename $0) [-h|--help]
@@ -47,6 +47,10 @@ check_cuda_version() {
check_cuda_version "$VERSION"
case $VERSION in
+ 10.2)
+ VERSION2="7.6"
+ VERSION3="1.5.2"
+ ;;
10.1)
VERSION2="7.6"
VERSION3="1.5.2"
diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml
index b3401b431..10ed3517a 100644
--- a/datavec/datavec-api/pom.xml
+++ b/datavec/datavec-api/pom.xml
@@ -117,7 +117,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml
index 6134bbf27..04420a5e9 100644
--- a/datavec/datavec-arrow/pom.xml
+++ b/datavec/datavec-arrow/pom.xml
@@ -56,7 +56,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-camel/pom.xml b/datavec/datavec-camel/pom.xml
index fdd8cdd86..3390242bc 100644
--- a/datavec/datavec-camel/pom.xml
+++ b/datavec/datavec-camel/pom.xml
@@ -110,7 +110,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml
index a7ee9a51d..1f99eab7c 100644
--- a/datavec/datavec-data/datavec-data-audio/pom.xml
+++ b/datavec/datavec-data/datavec-data-audio/pom.xml
@@ -72,7 +72,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-data/datavec-data-codec/pom.xml b/datavec/datavec-data/datavec-data-codec/pom.xml
index 2a65cb41a..a8e7ce493 100644
--- a/datavec/datavec-data/datavec-data-codec/pom.xml
+++ b/datavec/datavec-data/datavec-data-codec/pom.xml
@@ -59,7 +59,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml
index b97c9d5c6..fc85482a4 100644
--- a/datavec/datavec-data/datavec-data-image/pom.xml
+++ b/datavec/datavec-data/datavec-data-image/pom.xml
@@ -126,7 +126,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-data/datavec-data-nlp/pom.xml b/datavec/datavec-data/datavec-data-nlp/pom.xml
index 12df0fb08..0933a283e 100644
--- a/datavec/datavec-data/datavec-data-nlp/pom.xml
+++ b/datavec/datavec-data/datavec-data-nlp/pom.xml
@@ -67,7 +67,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-data/pom.xml b/datavec/datavec-data/pom.xml
index d6361cfb8..ef1558aab 100644
--- a/datavec/datavec-data/pom.xml
+++ b/datavec/datavec-data/pom.xml
@@ -58,7 +58,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml
index 589bf8b14..00fc890d8 100644
--- a/datavec/datavec-excel/pom.xml
+++ b/datavec/datavec-excel/pom.xml
@@ -58,7 +58,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-geo/pom.xml b/datavec/datavec-geo/pom.xml
index 50e843555..007d1cccd 100644
--- a/datavec/datavec-geo/pom.xml
+++ b/datavec/datavec-geo/pom.xml
@@ -49,7 +49,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-hadoop/pom.xml
index 5ec6d4c3f..7b74ead38 100644
--- a/datavec/datavec-hadoop/pom.xml
+++ b/datavec/datavec-hadoop/pom.xml
@@ -67,7 +67,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml
index 54a531047..bfafd25d0 100644
--- a/datavec/datavec-jdbc/pom.xml
+++ b/datavec/datavec-jdbc/pom.xml
@@ -65,7 +65,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml
index d2b15ffed..5c2c6f4ac 100644
--- a/datavec/datavec-local/pom.xml
+++ b/datavec/datavec-local/pom.xml
@@ -88,7 +88,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
index 2f508f09e..19733f297 100644
--- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
+++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
@@ -256,11 +256,9 @@ public class ExecutionTest {
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(
- new PythonTransform(
- "first = np.sin(first)\nsecond = np.cos(second)",
- schema
- )
- )
+ PythonTransform.builder().code(
+ "first = np.sin(first)\nsecond = np.cos(second)")
+ .outputSchema(schema).build())
.build();
List> functions = new ArrayList<>();
diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
similarity index 64%
rename from datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java
rename to datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
index 77ba53e26..37df8ae52 100644
--- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java
+++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
@@ -14,35 +14,40 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.datavec.python;
+package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.writable.*;
import org.datavec.api.transform.schema.Schema;
-import org.junit.Ignore;
+import org.datavec.local.transforms.LocalTransformExecutor;
+
+import org.datavec.api.writable.*;
+import org.datavec.python.PythonCondition;
+import org.datavec.python.PythonTransform;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-
+import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
+import static junit.framework.TestCase.assertTrue;
+import static org.datavec.api.transform.schema.Schema.Builder;
+import static org.junit.Assert.*;
+
+@NotThreadSafe
public class TestPythonTransformProcess {
- @Test(timeout = 60000L)
+
+ @Test()
public void testStringConcat() throws Exception{
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
@@ -54,10 +59,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(pythonCode, finalSchema)
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .build()
).build();
- List inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!"));
+ List inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
@@ -68,7 +75,7 @@ public class TestPythonTransformProcess {
@Test(timeout = 60000L)
public void testMixedTypes() throws Exception{
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
@@ -83,11 +90,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(pythonCode, finalSchema)
- ).build();
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .inputSchema(initialSchema)
+ .build() ).build();
- List inputs = Arrays.asList((Writable)
- new IntWritable(10),
+ List inputs = Arrays.asList((Writable)new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
@@ -105,7 +113,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2);
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@@ -116,12 +124,14 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(pythonCode, finalSchema)
- ).build();
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .build() ).build();
List inputs = Arrays.asList(
- (Writable) new NDArrayWritable(arr1),
- new NDArrayWritable(arr2)
+ (Writable)
+ new NDArrayWritable(arr1),
+ new NDArrayWritable(arr2)
);
List outputs = tp.execute(inputs);
@@ -139,7 +149,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2);
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@@ -150,11 +160,13 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(pythonCode, finalSchema)
- ).build();
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .build() ).build();
List inputs = Arrays.asList(
- (Writable) new NDArrayWritable(arr1),
+ (Writable)
+ new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
@@ -172,7 +184,7 @@ public class TestPythonTransformProcess {
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@@ -183,11 +195,14 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(pythonCode, finalSchema)
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .build()
).build();
List inputs = Arrays.asList(
- (Writable) new NDArrayWritable(arr1),
+ (Writable)
+ new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
@@ -199,8 +214,8 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
- public void testPythonFilter(){
- Schema schema = new Schema.Builder().addColumnInteger("column").build();
+ public void testPythonFilter() {
+ Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
@@ -210,17 +225,17 @@ public class TestPythonTransformProcess {
Filter filter = new ConditionFilter(condition);
- assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10))));
- assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1))));
- assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0))));
- assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1))));
- assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10))));
+ assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
+ assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
+ assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
+ assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
+ assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{
- Schema.Builder schemaBuilder = new Schema.Builder();
+ Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
@@ -241,33 +256,85 @@ public class TestPythonTransformProcess {
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- new PythonTransform(
- pythonCode,
- finalSchema
- )
+ PythonTransform.builder().code(pythonCode)
+ .outputSchema(finalSchema)
+ .build()
).filter(
filter
).build();
List> inputs = new ArrayList<>();
inputs.add(
- Arrays.asList((Writable) new IntWritable(5),
+ Arrays.asList(
+ (Writable)
+ new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
- Arrays.asList((Writable) new IntWritable(-3),
+ Arrays.asList(
+ (Writable)
+ new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
- Arrays.asList((Writable) new IntWritable(5),
+ Arrays.asList(
+ (Writable)
+ new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
+ LocalTransformExecutor.execute(inputs,tp);
}
-}
+
+
+ @Test
+ public void testPythonTransformNoOutputSpecified() throws Exception {
+ PythonTransform pythonTransform = PythonTransform.builder()
+ .code("a += 2; b = 'hello world'")
+ .returnAllInputs(true)
+ .build();
+ List> inputs = new ArrayList<>();
+ inputs.add(Arrays.asList((Writable)new IntWritable(1)));
+ Schema inputSchema = new Builder()
+ .addColumnInteger("a")
+ .build();
+
+ TransformProcess tp = new TransformProcess.Builder(inputSchema)
+ .transform(pythonTransform)
+ .build();
+ List> execute = LocalTransformExecutor.execute(inputs, tp);
+ assertEquals(3,execute.get(0).get(0).toInt());
+ assertEquals("hello world",execute.get(0).get(1).toString());
+
+ }
+
+ @Test
+ public void testNumpyTransform() throws Exception {
+ PythonTransform pythonTransform = PythonTransform.builder()
+ .code("a += 2; b = 'hello world'")
+ .returnAllInputs(true)
+ .build();
+
+ List> inputs = new ArrayList<>();
+ inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
+ Schema inputSchema = new Builder()
+ .addColumnNDArray("a",new long[]{1,1})
+ .build();
+
+ TransformProcess tp = new TransformProcess.Builder(inputSchema)
+ .transform(pythonTransform)
+ .build();
+ List> execute = LocalTransformExecutor.execute(inputs, tp);
+ assertFalse(execute.isEmpty());
+ assertNotNull(execute.get(0));
+ assertNotNull(execute.get(0).get(0));
+ assertEquals("hello world",execute.get(0).get(0).toString());
+ }
+
+}
\ No newline at end of file
diff --git a/datavec/datavec-perf/pom.xml b/datavec/datavec-perf/pom.xml
index 95f3135e5..a51b9aba1 100644
--- a/datavec/datavec-perf/pom.xml
+++ b/datavec/datavec-perf/pom.xml
@@ -59,7 +59,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml
index e60bc9219..55cf6c5da 100644
--- a/datavec/datavec-python/pom.xml
+++ b/datavec/datavec-python/pom.xml
@@ -28,15 +28,21 @@
- com.googlecode.json-simple
- json-simple
- 1.1
+ org.json
+ json
+ 20190722
org.bytedeco
cpython-platform
${cpython-platform.version}
+
+ org.bytedeco
+ numpy-platform
+ ${numpy.javacpp.version}
+
+
com.google.code.findbugs
jsr305
@@ -65,7 +71,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java
index a6ccc3036..ab49cf5ea 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java
@@ -16,10 +16,13 @@
package org.datavec.python;
+import lombok.Builder;
import lombok.Getter;
+import lombok.NoArgsConstructor;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
@@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType;
* @author Fariz Rahman
*/
@Getter
+@NoArgsConstructor
public class NumpyArray {
- private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
+ private static NativeOps nativeOps;
private long address;
private long[] shape;
private long[] strides;
- private DataType dtype = DataType.FLOAT;
+ private DataType dtype;
private INDArray nd4jArray;
+ static {
+ //initialize
+ Nd4j.scalar(1.0);
+ nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
+ }
- public NumpyArray(long address, long[] shape, long strides[], boolean copy){
+ @Builder
+ public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
this.address = address;
this.shape = shape;
this.strides = strides;
+ this.dtype = dtype;
setND4JArray();
if (copy){
nd4jArray = nd4jArray.dup();
@@ -57,8 +68,9 @@ public class NumpyArray {
public NumpyArray copy(){
return new NumpyArray(nd4jArray.dup());
}
+
public NumpyArray(long address, long[] shape, long strides[]){
- this(address, shape, strides, false);
+ this(address, shape, strides, false,DataType.FLOAT);
}
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
@@ -77,9 +89,9 @@ public class NumpyArray {
}
}
- private void setND4JArray(){
+ private void setND4JArray() {
long size = 1;
- for(long d: shape){
+ for(long d: shape) {
size *= d;
}
Pointer ptr = nativeOps.pointerForAddress(address);
@@ -88,10 +100,11 @@ public class NumpyArray {
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
- for (int i=0; i= 1,"Python code must not be empty!");
code = pythonCode;
}
- private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
- PythonVariables pyVars = new PythonVariables();
- int numCols = schema.numColumns();
- for (int i=0; i writables){
- PythonVariables ret = new PythonVariables();
- for (String name: pyInputs.getVariables()){
- int colIdx = inputSchema.getIndexOfColumn(name);
- Writable w = writables.get(colIdx);
- PythonVariables.Type pyType = pyInputs.getType(name);
- switch (pyType){
- case INT:
- if (w instanceof LongWritable){
- ret.addInt(name, ((LongWritable)w).get());
- }
- else{
- ret.addInt(name, ((IntWritable)w).get());
- }
- break;
- case FLOAT:
- ret.addFloat(name, ((DoubleWritable)w).get());
- break;
- case STR:
- ret.addStr(name, ((Text)w).toString());
- break;
- case NDARRAY:
- ret.addNDArray(name,((NDArrayWritable)w).get());
- break;
- }
-
- }
- return ret;
- }
@Override
- public void setInputSchema(Schema inputSchema){
+ public void setInputSchema(Schema inputSchema) {
this.inputSchema = inputSchema;
try{
pyInputs = schemaToPythonVariables(inputSchema);
PythonVariables pyOuts = new PythonVariables();
pyOuts.addInt("out");
- pythonTransform = new PythonTransform(
- code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered
- pyInputs,
- pyOuts
- );
+ pythonTransform = PythonTransform.builder()
+ .code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
+ .inputs(pyInputs)
+ .outputs(pyOuts)
+ .build();
+
}
catch (Exception e){
throw new RuntimeException(e);
@@ -127,41 +76,47 @@ public class PythonCondition implements Condition {
return inputSchema;
}
- public String[] outputColumnNames(){
+ @Override
+ public String[] outputColumnNames() {
String[] columnNames = new String[inputSchema.numColumns()];
inputSchema.getColumnNames().toArray(columnNames);
return columnNames;
}
+ @Override
public String outputColumnName(){
return outputColumnNames()[0];
}
+ @Override
public String[] columnNames(){
return outputColumnNames();
}
+ @Override
public String columnName(){
return outputColumnName();
}
+ @Override
public Schema transform(Schema inputSchema){
return inputSchema;
}
- public boolean condition(List list){
+ @Override
+ public boolean condition(List list) {
PythonVariables inputs = getPyInputsFromWritables(list);
try{
PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
return ret;
}
- catch (Exception e){
+ catch (Exception e) {
throw new RuntimeException(e);
}
-
}
+ @Override
public boolean condition(Object input){
return condition(input);
}
@@ -177,5 +132,37 @@ public class PythonCondition implements Condition {
throw new UnsupportedOperationException("not supported");
}
+ private PythonVariables getPyInputsFromWritables(List writables) {
+ PythonVariables ret = new PythonVariables();
-}
+ for (int i = 0; i < inputSchema.numColumns(); i++){
+ String name = inputSchema.getName(i);
+ Writable w = writables.get(i);
+ PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
+ switch (pyType){
+ case INT:
+ if (w instanceof LongWritable) {
+ ret.addInt(name, ((LongWritable)w).get());
+ }
+ else {
+ ret.addInt(name, ((IntWritable)w).get());
+ }
+
+ break;
+ case FLOAT:
+ ret.addFloat(name, ((DoubleWritable)w).get());
+ break;
+ case STR:
+ ret.addStr(name, w.toString());
+ break;
+ case NDARRAY:
+ ret.addNDArray(name,((NDArrayWritable)w).get());
+ break;
+ }
+ }
+
+ return ret;
+ }
+
+
+}
\ No newline at end of file
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
index c46d0d710..c6272e7ad 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
@@ -17,132 +17,504 @@
package org.datavec.python;
-import java.io.File;
-import java.io.FileInputStream;
-import java.util.HashMap;
+import java.io.*;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
-import java.util.regex.Pattern;
+
import lombok.extern.slf4j.Slf4j;
-import org.json.simple.JSONArray;
-import org.json.simple.JSONObject;
-import org.json.simple.parser.JSONParser;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.IOUtils;
+import org.json.JSONObject;
+import org.json.JSONArray;
import org.bytedeco.javacpp.*;
import org.bytedeco.cpython.*;
import static org.bytedeco.cpython.global.python.*;
+import org.bytedeco.numpy.global.numpy;
+
+import static org.datavec.python.PythonUtils.*;
+
+import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
+
/**
- * Python executioner
+ * Allows execution of python scripts managed by
+ * an internal interpreter.
+ * An end user may specify a python script to run
+ * via any of the execution methods available in this class.
+ *
+ * At static initialization time (when the class is first initialized)
+ * a number of components are setup:
+ * 1. The python path. A user may over ride this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY}
+ *
+ * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers,
+ * a user may also over ride the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType}
+ * values. This will allow the user to determine whether the javacpp default python path is used at all, and if so
+ * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external
+ * python distribution such as anaconda.
+ *
+ * 3. A main interpreter: This is the default interpreter to be used with the main thread.
+ * We may initialize one or more relative to the thread invoking the python code.
+ *
+ * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of
+ * native libraries needed by numpy are allowed to load in the proper order. If we don't do this,
+ * it causes a variety of issues with running numpy.
+ *
+ * 5. Various python scripts pre defined on the classpath included right with the java code.
+ * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior
+ * in order for us to manipulate values within the python memory, as well as pulling them out of memory
+ * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)}
+ * as an example.
+ *
+ * For more information on how this works, please take a look at the {@link #init()}
+ * method.
+ *
+ * Generally, a user defining a python script for use by the python executioner
+ * will have a set of defined target input values and output values.
+ * These values should not be present when actually running the script, but just referenced.
+ * In order to test your python script for execution outside the engine,
+ * we recommend commenting out a few default values as dummy input values.
+ * This will allow an end user to test their script before trying to use the server.
+ *
+ * In order to get output values out of a python script, all a user has to do
+ * is define the output variables they want being used in the final output in the actual pipeline.
+ * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name
+ * and based on the configured {@link PythonVariables} passed as outputs
+ * to one of the execution methods, we can pull the values out automatically.
+ *
+ * For input definitions, it is similar. You just define the values you want used in
+ * {@link PythonVariables} and we will automatically generate code for defining those values
+ * as desired for running. This allows the user to customize values dynamically
+ * at runtime but reference them by name in a python script.
+ *
*
* @author Fariz Rahman
+ * @author Adam Gibson
+ */
+
+
+/**
+ * Allows execution of python scripts managed by
+ * an internal interpreter.
+ * An end user may specify a python script to run
+ * via any of the execution methods available in this class.
+ *
+ * At static initialization time (when the class is first initialized)
+ * a number of components are setup:
+ * 1. The python path. A user may over ride this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY}
+ *
+ * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers,
+ * a user may also over ride the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType}
+ * values. This will allow the user to determine whether the javacpp default python path is used at all, and if so
+ * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external
+ * python distribution such as anaconda.
+ *
+ * 3. A main interpreter: This is the default interpreter to be used with the main thread.
+ * We may initialize one or more relative to the thread invoking the python code.
+ *
+ * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of
+ * native libraries needed by numpy are allowed to load in the proper order. If we don't do this,
+ * it causes a variety of issues with running numpy.
+ *
+ * 5. Various python scripts pre defined on the classpath included right with the java code.
+ * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior
+ * in order for us to manipulate values within the python memory, as well as pulling them out of memory
+ * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)}
+ * as an example.
+ *
+ * For more information on how this works, please take a look at the {@link #init()}
+ * method.
+ *
+ * Generally, a user defining a python script for use by the python executioner
+ * will have a set of defined target input values and output values.
+ * These values should not be present when actually running the script, but just referenced.
+ * In order to test your python script for execution outside the engine,
+ * we recommend commenting out a few default values as dummy input values.
+ * This will allow an end user to test their script before trying to use the server.
+ *
+ * In order to get output values out of a python script, all a user has to do
+ * is define the output variables they want being used in the final output in the actual pipeline.
+ * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name
+ * and based on the configured {@link PythonVariables} passed as outputs
+ * to one of the execution methods, we can pull the values out automatically.
+ *
+ * For input definitions, it is similar. You just define the values you want used in
+ * {@link PythonVariables} and we will automatically generate code for defining those values
+ * as desired for running. This allows the user to customize values dynamically
+ * at runtime but reference them by name in a python script.
+ *
+ *
+ * @author Fariz Rahman
+ * @author Adam Gibson
*/
@Slf4j
public class PythonExecutioner {
- private static PyObject module;
- private static PyObject globals;
- private static JSONParser parser = new JSONParser();
- private static Map gilStates = new HashMap<>();
+
+ private final static String fileVarName = "_f" + Nd4j.getRandom().nextInt();
+ private static boolean init;
+ public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path";
+ public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append";
+ public final static String DEFAULT_APPEND_TYPE = "before";
+ private static Map interpreters = new java.util.concurrent.ConcurrentHashMap<>();
+ private static PyThreadState currentThreadState;
+ private static PyThreadState mainThreadState;
+ public final static String ALL_VARIABLES_KEY = "allVariables";
+ public final static String MAIN_INTERPRETER_NAME = "main";
+ private static String clearVarsCode;
+
+ private static String currentInterpreter = MAIN_INTERPRETER_NAME;
+
+ /**
+ * One of a few desired values
+ * for how we should handle
+ * using javacpp's python path.
+ * BEFORE: Prepend the javacpp python path to the user-defined one
+ * AFTER: Append the javacpp python path to the user-defined one
+ * NONE: Don't use javacpp's python path at all
+ */
+ public enum JavaCppPathType {
+ BEFORE,AFTER,NONE
+ }
+
+ /**
+ * Set the python path.
+ * Generally you can just use the PYTHONPATH environment variable,
+ * but if you need to set it from code, this can work as well.
+ */
+ public static synchronized void setPythonPath() {
+ if(!init) {
+ try {
+ String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
+ if(path == null) {
+ log.info("Setting python default path");
+ File[] packages = numpy.cachePackages();
+ Py_SetPath(packages);
+ }
+ else {
+ log.info("Setting python path " + path);
+ StringBuffer sb = new StringBuffer();
+ File[] packages = numpy.cachePackages();
+
+ JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE,DEFAULT_APPEND_TYPE).toUpperCase());
+ switch(pathAppendValue) {
+ case BEFORE:
+ for(File cacheDir : packages) {
+ sb.append(cacheDir);
+ sb.append(java.io.File.pathSeparator);
+ }
+
+ sb.append(path);
+
+ log.info("Prepending javacpp python path " + sb.toString());
+ break;
+ case AFTER:
+ sb.append(path);
+
+ for(File cacheDir : packages) {
+ sb.append(cacheDir);
+ sb.append(java.io.File.pathSeparator);
+ }
+
+ log.info("Appending javacpp python path " + sb.toString());
+ break;
+ case NONE:
+ log.info("Not appending javacpp path");
+ sb.append(path);
+ break;
+ }
+
+ //prepend the javacpp packages
+ log.info("Final python path " + sb.toString());
+
+ Py_SetPath(sb.toString());
+ }
+ } catch (IOException e) {
+ log.error("Failed to set python path.", e);
+ }
+ }
+ else {
+ throw new IllegalStateException("Unable to reset python path. Already initialized.");
+ }
+ }
+
+ /**
+ * Initialize the name space and the python execution
+ * Calling this method more than once will be a no op
+ */
+ public static synchronized void init() {
+ if(init) {
+ return;
+ }
+
+ try(InputStream is = new org.nd4j.linalg.io.ClassPathResource("pythonexec/clear_vars.py").getInputStream()) {
+ clearVarsCode = IOUtils.toString(new java.io.InputStreamReader(is));
+ } catch (java.io.IOException e) {
+ throw new IllegalStateException("Unable to read pythonexec/clear_vars.py");
+ }
+
+ log.info("CPython: PyEval_InitThreads()");
+ PyEval_InitThreads();
+ log.info("CPython: Py_InitializeEx()");
+ Py_InitializeEx(0);
+ log.info("CPython: PyGILState_Release()");
+ init = true;
+ interpreters.put(MAIN_INTERPRETER_NAME, PyThreadState_Get());
+ numpy._import_array();
+ applyPatches();
+ }
+
+
+ /**
+ * Run {@link #resetInterpreter(String)}
+ * on all interpreters.
+ */
+ public static void resetAllInterpreters() {
+ for(String interpreter : interpreters.keySet()) {
+ resetInterpreter(interpreter);
+ }
+ }
+
+ /**
+ * Reset the main interpreter.
+ * For more information see {@link #resetInterpreter(String)}
+ */
+ public static void resetMainInterpreter() {
+ resetInterpreter(MAIN_INTERPRETER_NAME);
+ }
+
+ /**
+ * Reset the interpreter with the given name.
+ * Runs pythonexec/clear_vars.py
+ * For more information see:
+ * https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
+ * @param interpreterName the interpreter name to
+ * reset
+ */
+ public static synchronized void resetInterpreter(String interpreterName) {
+ Preconditions.checkState(hasInterpreter(interpreterName));
+ log.info("Resetting interpreter " + interpreterName);
+ String oldInterpreter = currentInterpreter;
+ setInterpreter(interpreterName);
+ exec("pass");
+ //exec(interpreterName); // ??
+ setInterpreter(oldInterpreter);
+ }
+
+ /**
+ * Clear the non-main interpreters.
+ */
+ public static void clearNonMainInterpreters() {
+ for(String key : interpreters.keySet()) {
+ if(!key.equals(MAIN_INTERPRETER_NAME)) {
+ deleteInterpreter(key);
+ }
+ }
+ }
+
+ public static PythonVariables defaultPythonVariableOutput() {
+ PythonVariables ret = new PythonVariables();
+ ret.add(ALL_VARIABLES_KEY, PythonVariables.Type.DICT);
+ return ret;
+ }
+
+ /**
+ * Return the python path being used.
+ * @return a string specifying the python path in use
+ */
+ public static String getPythonPath() {
+ return new BytePointer(Py_GetPath()).getString();
+ }
+
static {
+ setPythonPath();
init();
}
- public static void init(){
- log.info("CPython: Py_InitializeEx()");
- Py_InitializeEx(1);
- log.info("CPython: PyEval_InitThreads()");
- PyEval_InitThreads();
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- log.info("CPython: PyThreadState_Get()");
+
+ /* ---------sub-interpreter and gil management-----------*/
+ public static void setInterpreter(String interpreterName) {
+ if (!hasInterpreter(interpreterName)){
+ PyThreadState main = PyThreadState_Get();
+ PyThreadState ts = Py_NewInterpreter();
+
+ interpreters.put(interpreterName, ts);
+ PyThreadState_Swap(main);
+ }
+
+ currentInterpreter = interpreterName;
+ }
+
+ /**
+ * Returns the current interpreter.
+ * @return
+ */
+ public static String getInterpreter() {
+ return currentInterpreter;
+ }
+
+
+ public static boolean hasInterpreter(String interpreterName){
+ return interpreters.containsKey(interpreterName);
+ }
+
+ public static void deleteInterpreter(String interpreterName) {
+ if (interpreterName.equals("main")){
+ throw new IllegalArgumentException("Can not delete main interpreter");
+ }
+
+ Py_EndInterpreter(interpreters.remove(interpreterName));
+ }
+
+ private static synchronized void acquireGIL() {
+ log.info("acquireGIL()");
+ log.info("CPython: PyEval_SaveThread()");
+ mainThreadState = PyEval_SaveThread();
+ log.info("CPython: PyThreadState_New()");
+ currentThreadState = PyThreadState_New(interpreters.get(currentInterpreter).interp());
+ log.info("CPython: PyEval_RestoreThread()");
+ PyEval_RestoreThread(currentThreadState);
+ log.info("CPython: PyThreadState_Swap()");
+ PyThreadState_Swap(currentThreadState);
+
+ }
+
+ private static synchronized void releaseGIL() {
+ log.info("CPython: PyEval_SaveThread()");
PyEval_SaveThread();
+ log.info("CPython: PyEval_RestoreThread()");
+ PyEval_RestoreThread(mainThreadState);
}
- public static void free(){
- Py_Finalize();
+ /* -------------------*/
+ /**
+ * Print the python version to standard out.
+ */
+ public static void printPythonVersion() {
+ exec("import sys; print(sys.version) sys.stdout.flush();");
}
- private static String inputCode(PythonVariables pyInputs)throws Exception{
- String inputCode = "loc={};";
+
+
+ private static String inputCode(PythonVariables pyInputs)throws Exception {
+ String inputCode = "";
if (pyInputs == null){
return inputCode;
}
+
Map strInputs = pyInputs.getStrVariables();
Map intInputs = pyInputs.getIntVariables();
Map floatInputs = pyInputs.getFloatVariables();
- Map ndInputs = pyInputs.getNDArrayVariables();
+ Map ndInputs = pyInputs.getNdVars();
Map listInputs = pyInputs.getListVariables();
Map fileInputs = pyInputs.getFileVariables();
+ Map> dictInputs = pyInputs.getDictVariables();
- String[] VarNames;
+ String[] varNames;
- VarNames = strInputs.keySet().toArray(new String[strInputs.size()]);
- for(Object varName: VarNames){
+ varNames = strInputs.keySet().toArray(new String[strInputs.size()]);
+ for(String varName: varNames) {
+ Preconditions.checkNotNull(varName,"Var name is null!");
+ Preconditions.checkNotNull(varName.isEmpty(),"Var name can not be empty!");
String varValue = strInputs.get(varName);
- inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n";
- inputCode += "loc['" + varName + "']=" + varName + "\n";
+ //inputCode += varName + "= {}\n";
+ if(varValue != null)
+ inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n";
+ else {
+ inputCode += varName + " = ''\n";
+ }
}
- VarNames = intInputs.keySet().toArray(new String[intInputs.size()]);
- for(String varName: VarNames){
+ varNames = intInputs.keySet().toArray(new String[intInputs.size()]);
+ for(String varName: varNames) {
Long varValue = intInputs.get(varName);
- inputCode += varName + " = " + varValue.toString() + "\n";
- inputCode += "loc['" + varName + "']=" + varName + "\n";
+ if(varValue != null)
+ inputCode += varName + " = " + varValue.toString() + "\n";
+ else {
+ inputCode += " = 0\n";
+ }
}
- VarNames = floatInputs.keySet().toArray(new String[floatInputs.size()]);
- for(String varName: VarNames){
+ varNames = dictInputs.keySet().toArray(new String[dictInputs.size()]);
+ for(String varName: varNames) {
+ Map,?> varValue = dictInputs.get(varName);
+ if(varValue != null) {
+ throw new IllegalArgumentException("Unable to generate input code for dictionaries.");
+ }
+ else {
+ inputCode += " = {}\n";
+ }
+ }
+
+ varNames = floatInputs.keySet().toArray(new String[floatInputs.size()]);
+ for(String varName: varNames){
Double varValue = floatInputs.get(varName);
- inputCode += varName + " = " + varValue.toString() + "\n";
- inputCode += "loc['" + varName + "']=" + varName + "\n";
+ if(varValue != null)
+ inputCode += varName + " = " + varValue.toString() + "\n";
+ else {
+ inputCode += varName + " = 0.0\n";
+ }
}
- VarNames = listInputs.keySet().toArray(new String[listInputs.size()]);
- for (String varName: VarNames){
+ varNames = listInputs.keySet().toArray(new String[listInputs.size()]);
+ for (String varName: varNames) {
Object[] varValue = listInputs.get(varName);
- String listStr = jArrayToPyString(varValue);
- inputCode += varName + " = " + listStr + "\n";
- inputCode += "loc['" + varName + "']=" + varName + "\n";
+ if(varValue != null) {
+ String listStr = jArrayToPyString(varValue);
+ inputCode += varName + " = " + listStr + "\n";
+ }
+ else {
+ inputCode += varName + " = []\n";
+ }
+
}
- VarNames = fileInputs.keySet().toArray(new String[fileInputs.size()]);
- for(Object varName: VarNames){
+ varNames = fileInputs.keySet().toArray(new String[fileInputs.size()]);
+ for(String varName: varNames) {
String varValue = fileInputs.get(varName);
- inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n";
- inputCode += "loc['" + varName + "']=" + varName + "\n";
+ if(varValue != null)
+ inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n";
+ else {
+ inputCode += varName + " = ''\n";
+ }
}
- if (ndInputs.size()> 0){
- inputCode += "import ctypes; import numpy as np;";
- VarNames = ndInputs.keySet().toArray(new String[ndInputs.size()]);
+ if (!ndInputs.isEmpty()) {
+ inputCode += "import ctypes\n\nimport sys\nimport numpy as np\n";
+ varNames = ndInputs.keySet().toArray(new String[ndInputs.size()]);
- String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape);";
+ String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape)\n";
inputCode += converter;
- for(String varName: VarNames){
+ for(String varName: varNames) {
NumpyArray npArr = ndInputs.get(varName);
+ if(npArr == null)
+ continue;
+
npArr = npArr.copy();
String shapeStr = "(";
for (long d: npArr.getShape()){
- shapeStr += String.valueOf(d) + ",";
+ shapeStr += d + ",";
}
shapeStr += ")";
String code;
String ctype;
- if (npArr.getDtype() == DataType.FLOAT){
+ if (npArr.getDtype() == DataType.FLOAT) {
ctype = "ctypes.c_float";
}
- else if (npArr.getDtype() == DataType.DOUBLE){
+ else if (npArr.getDtype() == DataType.DOUBLE) {
ctype = "ctypes.c_double";
}
- else if (npArr.getDtype() == DataType.SHORT){
+ else if (npArr.getDtype() == DataType.SHORT) {
ctype = "ctypes.c_int16";
}
- else if (npArr.getDtype() == DataType.INT){
+ else if (npArr.getDtype() == DataType.INT) {
ctype = "ctypes.c_int32";
}
else if (npArr.getDtype() == DataType.LONG){
@@ -152,10 +524,9 @@ public class PythonExecutioner {
throw new Exception("Unsupported data type: " + npArr.getDtype().toString() + ".");
}
- code = "__arr_converter(" + String.valueOf(npArr.getAddress()) + "," + shapeStr + "," + ctype + ")";
- code = varName + "=" + code + "\n";
+ code = "__arr_converter(" + npArr.getAddress() + "," + shapeStr + "," + ctype + ")";
+ code = varName + "=" + code + "\n";
inputCode += code;
- inputCode += "loc['" + varName + "']=" + varName + "\n";
}
}
@@ -163,49 +534,62 @@ public class PythonExecutioner {
}
- private static void _readOutputs(PythonVariables pyOutputs){
- String json = read(getTempFile());
+ private static synchronized void _readOutputs(PythonVariables pyOutputs) throws IOException {
File f = new File(getTempFile());
+ Preconditions.checkState(f.exists(),"File " + f.getAbsolutePath() + " failed to get written for reading outputs!");
+ String json = FileUtils.readFileToString(f, Charset.defaultCharset());
+ log.info("Executioner output: ");
+ log.info(json);
f.delete();
- JSONParser p = new JSONParser();
- try{
- JSONObject jobj = (JSONObject) p.parse(json);
- for (String varName: pyOutputs.getVariables()){
+
+ if(json.isEmpty()) {
+ log.warn("No json found fore reading outputs. Returning.");
+ return;
+ }
+
+ try {
+ JSONObject jobj = new JSONObject(json);
+ for (String varName: pyOutputs.getVariables()) {
PythonVariables.Type type = pyOutputs.getType(varName);
- if (type == PythonVariables.Type.NDARRAY){
+ if (type == PythonVariables.Type.NDARRAY) {
JSONObject varValue = (JSONObject)jobj.get(varName);
- long address = (Long)varValue.get("address");
- JSONArray shapeJson = (JSONArray)varValue.get("shape");
- JSONArray stridesJson = (JSONArray)varValue.get("strides");
+ long address = (Long) varValue.getLong("address");
+ JSONArray shapeJson = (JSONArray) varValue.get("shape");
+ JSONArray stridesJson = (JSONArray) varValue.get("strides");
long[] shape = jsonArrayToLongArray(shapeJson);
long[] strides = jsonArrayToLongArray(stridesJson);
String dtypeName = (String)varValue.get("dtype");
DataType dtype;
- if (dtypeName.equals("float64")){
+ if (dtypeName.equals("float64")) {
dtype = DataType.DOUBLE;
}
- else if (dtypeName.equals("float32")){
+ else if (dtypeName.equals("float32")) {
dtype = DataType.FLOAT;
}
- else if (dtypeName.equals("int16")){
+ else if (dtypeName.equals("int16")) {
dtype = DataType.SHORT;
}
- else if (dtypeName.equals("int32")){
+ else if (dtypeName.equals("int32")) {
dtype = DataType.INT;
}
- else if (dtypeName.equals("int64")){
+ else if (dtypeName.equals("int64")) {
dtype = DataType.LONG;
}
else{
throw new Exception("Unsupported array type " + dtypeName + ".");
}
+
pyOutputs.setValue(varName, new NumpyArray(address, shape, strides, dtype, true));
-
}
- else if (type == PythonVariables.Type.LIST){
- JSONArray varValue = (JSONArray)jobj.get(varName);
- pyOutputs.setValue(varName, varValue.toArray());
+ else if (type == PythonVariables.Type.LIST) {
+ JSONArray varValue = (JSONArray) jobj.get(varName);
+ pyOutputs.setValue(varName, varValue);
+ }
+ else if (type == PythonVariables.Type.DICT) {
+ Map map = toMap((JSONObject) jobj.get(varName));
+ pyOutputs.setValue(varName, map);
+
}
else{
pyOutputs.setValue(varName, jobj.get(varName));
@@ -217,266 +601,422 @@ public class PythonExecutioner {
}
}
- private static void acquireGIL(){
- log.info("---_enterSubInterpreter()---");
- if (PyGILState_Check() != 1){
- gilStates.put(Thread.currentThread().getId(), PyGILState_Ensure());
- log.info("GIL ensured");
+
+
+
+ private static synchronized void _exec(String code) {
+ log.info(code);
+ log.info("CPython: PyRun_SimpleStringFlag()");
+
+ int result = PyRun_SimpleStringFlags(code, null);
+ if (result != 0) {
+ log.info("CPython: PyErr_Print");
+ PyErr_Print();
+ throw new RuntimeException("exec failed");
}
}
- private static void releaseGIL(){
- if (PyGILState_Check() == 1){
- log.info("Releasing gil...");
- PyGILState_Release(gilStates.get(Thread.currentThread().getId()));
- log.info("Gil released.");
- }
-
+ private static synchronized void _exec_wrapped(String code) {
+ _exec(getWrappedCode(code));
}
/**
* Executes python code. Also manages python thread state.
- * @param code
+ * @param code the code to run
*/
- public static void exec(String code){
- code = getFunctionalCode("__f_" + Thread.currentThread().getId(), code);
+ public static void exec(String code) {
+ code = getWrappedCode(code);
+ if(code.contains("import numpy") && !getInterpreter().equals("main")) {// FIXME
+ throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons.");
+ }
+
acquireGIL();
- log.info("CPython: PyRun_SimpleStringFlag()");
- log.info(code);
- int result = PyRun_SimpleStringFlags(code, null);
- if (result != 0){
- PyErr_Print();
- throw new RuntimeException("exec failed");
+ _exec(code);
+ log.info("Exec done");
+ releaseGIL();
+ }
+
+ private static boolean _hasGlobalVariable(String varName){
+ PyObject mainModule = PyImport_AddModule("__main__");
+ PyObject var = PyObject_GetAttrString(mainModule, varName);
+ boolean hasVar = var != null;
+ Py_DecRef(var);
+ return hasVar;
+ }
+
+ /**
+ * Executes python code and looks for methods setup() and run()
+ * If both setup() and run() are found, both are executed for the first
+ * time and for subsequent calls only run() is executed.
+ */
+ public static void execWithSetupAndRun(String code) {
+ code = getWrappedCode(code);
+ if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME
+ throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons.");
+ }
+
+ acquireGIL();
+ _exec(code);
+ if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){
+ log.debug("setup() and run() methods found.");
+ if (!_hasGlobalVariable("__setup_done__")){
+ log.debug("Calling setup()...");
+ _exec("setup()");
+ _exec("__setup_done__ = True");
+ }
+ log.debug("Calling run()...");
+ _exec("run()");
}
log.info("Exec done");
releaseGIL();
}
- public static void exec(String code, PythonVariables pyOutputs){
- exec(code + '\n' + outputCode(pyOutputs));
- _readOutputs(pyOutputs);
+ /**
+ * Executes python code and looks for methods setup() and run()
+ * If both setup() and run() are found, both are executed for the first
+ * time and for subsequent calls only run() is executed.
+ */
+ public static void execWithSetupAndRun(String code, PythonVariables pyOutputs) {
+ code = getWrappedCode(code);
+ if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME
+ throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons.");
+ }
+
+ acquireGIL();
+ _exec(code);
+ if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){
+ log.debug("setup() and run() methods found.");
+ if (!_hasGlobalVariable("__setup_done__")){
+ log.debug("Calling setup()...");
+ _exec("setup()");
+ _exec("__setup_done__ = True");
+ }
+ log.debug("Calling run()...");
+ _exec("__out = run();for (k,v) in __out.items(): globals()[k]=v");
+ }
+ log.info("Exec done");
+ try {
+
+ _readOutputs(pyOutputs);
+
+ } catch (IOException e) {
+ log.error("Failed to read outputs", e);
+ }
+
+ releaseGIL();
}
- public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{
+ /**
+ * Run the given code with the given python outputs
+ * @param code the code to run
+ * @param pyOutputs the outputs to run
+ */
+ public static void exec(String code, PythonVariables pyOutputs) {
+
+ exec(code + '\n' + outputCode(pyOutputs));
+ try {
+
+ _readOutputs(pyOutputs);
+
+ } catch (IOException e) {
+ log.error("Failed to read outputs", e);
+ }
+
+ releaseGIL();
+ }
+
+
+ /**
+ * Execute the given python code with the given
+ * {@link PythonVariables} as inputs and outputs
+ * @param code the code to run
+ * @param pyInputs the inputs to the code
+ * @param pyOutputs the outputs to the code
+ * @throws Exception
+ */
+ public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception {
String inputCode = inputCode(pyInputs);
exec(inputCode + code, pyOutputs);
}
-
- public static PythonVariables exec(PythonTransform transform) throws Exception{
- if (transform.getInputs() != null && transform.getInputs().getVariables().length > 0){
- throw new Exception("Required inputs not provided.");
+ /**
+ * Execute the given python code
+ * with the {@link PythonVariables}
+ * inputs and outputs for storing the values
+ * specified by the user and needed by the user
+ * as output
+ * @param code the python code to execute
+ * @param pyInputs the python variables input in to the python script
+ * @param pyOutputs the python variables output returned by the python script
+ * @throws Exception
+ */
+ public static void execWithSetupAndRun(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception {
+ String inputCode = inputCode(pyInputs);
+ code = inputCode +code;
+ code = getWrappedCode(code);
+ if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME
+ throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons.");
}
- exec(transform.getCode(), null, transform.getOutputs());
- return transform.getOutputs();
+ acquireGIL();
+ _exec(code);
+ if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){
+ log.debug("setup() and run() methods found.");
+ if (!_hasGlobalVariable("__setup_done__")){
+ releaseGIL(); // required
+ acquireGIL();
+ log.debug("Calling setup()...");
+ _exec("setup()");
+ _exec("__setup_done__ = True");
+ }else{
+ log.debug("setup() already called once.");
+ }
+ log.debug("Calling run()...");
+ releaseGIL(); // required
+ acquireGIL();
+ _exec("import inspect\n"+
+ "__out = run(**{k:globals()[k]for k in inspect.getfullargspec(run).args})\n"+
+ "globals().update(__out)");
+ }
+ releaseGIL(); // required
+ acquireGIL();
+ _exec(outputCode(pyOutputs));
+ log.info("Exec done");
+ try {
+
+ _readOutputs(pyOutputs);
+
+ } catch (IOException e) {
+ log.error("Failed to read outputs", e);
+ }
+
+ releaseGIL();
}
- public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception{
+
+
+ private static String interpreterNameFromTransform(PythonTransform transform){
+ return transform.getName().replace("-", "_");
+ }
+
+
+ /**
+ * Run a {@link PythonTransform} with the given inputs
+ * @param transform the transform to run
+ * @param inputs the inputs to the transform
+ * @return the output variables
+ * @throws Exception
+ */
+ public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception {
+ String name = interpreterNameFromTransform(transform);
+ setInterpreter(name);
+ Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!");
exec(transform.getCode(), inputs, transform.getOutputs());
return transform.getOutputs();
}
-
-
- public static String evalSTRING(String varName){
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- PyObject xObj = PyDict_GetItemString(globals, varName);
- PyObject bytes = PyUnicode_AsEncodedString(xObj, "UTF-8", "strict");
- BytePointer bp = PyBytes_AsString(bytes);
- String ret = bp.getString();
- Py_DecRef(xObj);
- Py_DecRef(bytes);
- return ret;
+ public static PythonVariables execWithSetupAndRun(PythonTransform transform, PythonVariables inputs)throws Exception {
+ String name = interpreterNameFromTransform(transform);
+ setInterpreter(name);
+ Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!");
+ execWithSetupAndRun(transform.getCode(), inputs, transform.getOutputs());
+ return transform.getOutputs();
}
- public static long evalINTEGER(String varName){
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- PyObject xObj = PyDict_GetItemString(globals, varName);
- long ret = PyLong_AsLongLong(xObj);
- return ret;
+
+ /**
+ * Run the code and return the outputs
+ * @param code the code to run
+ * @return all python variables
+ */
+ public static PythonVariables execAndReturnAllVariables(String code) {
+ exec(code + '\n' + outputCodeForAllVariables());
+ PythonVariables allVars = new PythonVariables();
+ allVars.addDict(ALL_VARIABLES_KEY);
+ try {
+ _readOutputs(allVars);
+ }catch (IOException e) {
+ log.error("Failed to read outputs", e);
+ }
+
+ return expandInnerDict(allVars, ALL_VARIABLES_KEY);
+ }
+ public static PythonVariables execWithSetupRunAndReturnAllVariables(String code) {
+ execWithSetupAndRun(code + '\n' + outputCodeForAllVariables());
+ PythonVariables allVars = new PythonVariables();
+ allVars.addDict(ALL_VARIABLES_KEY);
+ try {
+ _readOutputs(allVars);
+ }catch (IOException e) {
+ log.error("Failed to read outputs", e);
+ }
+
+ return expandInnerDict(allVars, ALL_VARIABLES_KEY);
}
- public static double evalFLOAT(String varName){
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- PyObject xObj = PyDict_GetItemString(globals, varName);
- double ret = PyFloat_AsDouble(xObj);
- return ret;
+ /**
+ *
+ * @param code code string to run
+ * @param pyInputs python input variables
+ * @return all python variables
+ * @throws Exception throws when there's an issue while execution of python code
+ */
+ public static PythonVariables execAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception {
+ String inputCode = inputCode(pyInputs);
+ return execAndReturnAllVariables(inputCode + code);
+ }
+ public static PythonVariables execWithSetupRunAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception {
+ String inputCode = inputCode(pyInputs);
+ return execWithSetupRunAndReturnAllVariables(inputCode + code);
}
- public static Object[] evalLIST(String varName) throws Exception{
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- PyObject xObj = PyDict_GetItemString(globals, varName);
- PyObject strObj = PyObject_Str(xObj);
- PyObject bytes = PyUnicode_AsEncodedString(strObj, "UTF-8", "strict");
- BytePointer bp = PyBytes_AsString(bytes);
- String listStr = bp.getString();
- Py_DecRef(xObj);
- Py_DecRef(bytes);
- JSONArray jsonArray = (JSONArray)parser.parse(listStr.replace("\'", "\""));
- return jsonArray.toArray();
+
+ /**
+ * Evaluate a string based on the
+ * current variable name.
+ * This variable named needs to be present
+ * or defined earlier in python code
+ * in order to pull out the values.
+ *
+ * @param varName the variable name to evaluate
+ * @return the evaluated value
+ */
+ public static String evalString(String varName) {
+ PythonVariables vars = new PythonVariables();
+ vars.addStr(varName);
+ exec("print('')", vars);
+ return vars.getStrValue(varName);
}
- public static NumpyArray evalNDARRAY(String varName) throws Exception{
- log.info("CPython: PyImport_AddModule()");
- module = PyImport_AddModule("__main__");
- log.info("CPython: PyModule_GetDict()");
- globals = PyModule_GetDict(module);
- PyObject xObj = PyDict_GetItemString(globals, varName);
- PyObject arrayInterface = PyObject_GetAttrString(xObj, "__array_interface__");
- PyObject data = PyDict_GetItemString(arrayInterface, "data");
- PyObject zero = PyLong_FromLong(0);
- PyObject addressObj = PyObject_GetItem(data, zero);
- long address = PyLong_AsLongLong(addressObj);
- PyObject shapeObj = PyObject_GetAttrString(xObj, "shape");
- int ndim = (int)PyObject_Size(shapeObj);
- PyObject iObj;
- long shape[] = new long[ndim];
- for (int i=0; i 0)
+ outputCode = outputCode.substring(0, outputCode.length() - 1);
+ outputCode += "})";
+ outputCode += "\nwith open('" + getTempFile() + "', 'w') as " + fileVarName + ":" + fileVarName + ".write(" + outputVarName() + ")";
+
+
return outputCode;
}
- private static String read(String path){
- try{
- File file = new File(path);
- FileInputStream fis = new FileInputStream(file);
- byte[] data = new byte[(int) file.length()];
- fis.read(data);
- fis.close();
- String str = new String(data, "UTF-8");
- return str;
- }
- catch (Exception e){
- return "";
- }
- }
- private static String jArrayToPyString(Object[] array){
+ private static String jArrayToPyString(Object[] array) {
String str = "[";
- for (int i=0; i < array.length; i++){
+ for (int i = 0; i < array.length; i++){
Object obj = array[i];
if (obj instanceof Object[]){
str += jArrayToPyString((Object[])obj);
@@ -496,32 +1036,109 @@ public class PythonExecutioner {
return str;
}
- private static String escapeStr(String str){
+ private static String escapeStr(String str) {
+ if(str == null)
+ return null;
str = str.replace("\\", "\\\\");
str = str.replace("\"\"\"", "\\\"\\\"\\\"");
return str;
}
- private static String getFunctionalCode(String functionName, String code){
- String out = String.format("def %s():\n", functionName);
- for(String line: code.split(Pattern.quote("\n"))){
- out += " " + line + "\n";
+ private static String getWrappedCode(String code) {
+ try(InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) {
+ String base = IOUtils.toString(is, Charset.defaultCharset());
+ StringBuffer indentedCode = new StringBuffer();
+ for(String split : code.split("\n")) {
+ indentedCode.append(" " + split + "\n");
+
+ }
+
+ String out = base.replace(" pass",indentedCode);
+ return out;
+ } catch (IOException e) {
+ throw new IllegalStateException("Unable to read python code!",e);
}
- return out + "\n\n" + functionName + "()\n";
+
}
- private static String getTempFile(){
- String ret = "temp_" + Thread.currentThread().getId() + ".json";
+
+ private static String getTempFile() {
+ String ret = "temp_" + Thread.currentThread().getId() + "_" + currentInterpreter + ".json";
log.info(ret);
return ret;
}
- private static long[] jsonArrayToLongArray(JSONArray jsonArray){
- long[] longs = new long[jsonArray.size()];
- for (int i=0; i _getPatches() {
+ exec("import numpy as np");
+ exec( "__overrides_path = np.core.overrides.__file__");
+ exec("__random_path = np.random.__file__");
+
+ List patches = new ArrayList<>();
+
+ patches.add(new String[]{
+ "pythonexec/patch0.py",
+ evalString("__overrides_path")
+ });
+ patches.add(new String[]{
+ "pythonexec/patch1.py",
+ evalString("__random_path")
+ });
+
+ return patches;
+ }
+
+ private static void _applyPatch(String src, String dest){
+ try(InputStream is = new ClassPathResource(src).getInputStream()) {
+ String patch = IOUtils.toString(is, Charset.defaultCharset());
+ FileUtils.write(new File(dest), patch, "utf-8");
+ }
+ catch(IOException e){
+ throw new RuntimeException("Error reading resource.");
+ }
+ }
+
+ private static boolean _checkPatchApplied(String dest) {
+ try {
+ return FileUtils.readFileToString(new File(dest), "utf-8").startsWith("#patch");
+ } catch (IOException e) {
+ throw new RuntimeException("Error patching numpy");
+
+ }
+ }
+
+ private static void applyPatches() {
+ for (String[] patch : _getPatches()){
+ if (_checkPatchApplied(patch[1])){
+ log.info("Patch already applied for " + patch[1]);
+ }
+ else{
+ _applyPatch(patch[0], patch[1]);
+ log.info("Applied patch for " + patch[1]);
+ }
+ }
+ for (String[] patch: _getPatches()){
+ if (!_checkPatchApplied(patch[1])){
+ throw new RuntimeException("Error patching numpy");
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java
index e3b3fb2bf..8f2460035 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java
@@ -16,16 +16,29 @@
package org.datavec.python;
+import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
+import org.apache.commons.io.IOUtils;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
+import org.nd4j.base.Preconditions;
+import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
+import org.nd4j.linalg.io.ClassPathResource;
+import org.nd4j.shade.jackson.core.JsonProcessingException;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.UUID;
+import static org.datavec.python.PythonUtils.schemaToPythonVariables;
+
/**
* Row-wise Transform that applies arbitrary python code on each row
*
@@ -34,31 +47,87 @@ import java.util.UUID;
@NoArgsConstructor
@Data
-public class PythonTransform implements Transform{
+public class PythonTransform implements Transform {
+
private String code;
- private PythonVariables pyInputs;
- private PythonVariables pyOutputs;
- private String name;
+ private PythonVariables inputs;
+ private PythonVariables outputs;
+ private String name = UUID.randomUUID().toString();
private Schema inputSchema;
private Schema outputSchema;
+ private String outputDict;
+ private boolean returnAllVariables;
+ private boolean setupAndRun = false;
- public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{
+ @Builder
+ public PythonTransform(String code,
+ PythonVariables inputs,
+ PythonVariables outputs,
+ String name,
+ Schema inputSchema,
+ Schema outputSchema,
+ String outputDict,
+ boolean returnAllInputs,
+ boolean setupAndRun) {
+ Preconditions.checkNotNull(code,"No code found to run!");
this.code = code;
- this.pyInputs = pyInputs;
- this.pyOutputs = pyOutputs;
- this.name = UUID.randomUUID().toString();
+ this.returnAllVariables = returnAllInputs;
+ this.setupAndRun = setupAndRun;
+ if(inputs != null)
+ this.inputs = inputs;
+ if(outputs != null)
+ this.outputs = outputs;
+
+ if(name != null)
+ this.name = name;
+ if (outputDict != null) {
+ this.outputDict = outputDict;
+ this.outputs = new PythonVariables();
+ this.outputs.addDict(outputDict);
+
+ String helpers;
+ try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
+ helpers = IOUtils.toString(is, Charset.defaultCharset());
+
+ }catch (IOException e){
+ throw new RuntimeException("Error reading python code");
+ }
+ this.code += "\n\n" + helpers;
+ this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
+ }
+
+ try {
+ if(inputSchema != null) {
+ this.inputSchema = inputSchema;
+ if(inputs == null || inputs.isEmpty()) {
+ this.inputs = schemaToPythonVariables(inputSchema);
+ }
+ }
+
+ if(outputSchema != null) {
+ this.outputSchema = outputSchema;
+ if(outputs == null || outputs.isEmpty()) {
+ this.outputs = schemaToPythonVariables(outputSchema);
+ }
+ }
+ }catch(Exception e) {
+ throw new IllegalStateException(e);
+ }
+
}
+
@Override
- public void setInputSchema(Schema inputSchema){
+ public void setInputSchema(Schema inputSchema) {
+ Preconditions.checkNotNull(inputSchema,"No input schema found!");
this.inputSchema = inputSchema;
try{
- pyInputs = schemaToPythonVariables(inputSchema);
+ inputs = schemaToPythonVariables(inputSchema);
}catch (Exception e){
throw new RuntimeException(e);
}
- if (outputSchema == null){
+ if (outputSchema == null && outputDict == null){
outputSchema = inputSchema;
}
@@ -88,12 +157,42 @@ public class PythonTransform implements Transform{
throw new UnsupportedOperationException("Not yet implemented");
}
+
+
+
@Override
- public List map(List writables){
+ public List map(List writables) {
PythonVariables pyInputs = getPyInputsFromWritables(writables);
+ Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");
+
+
try{
- PythonExecutioner.exec(code, pyInputs, pyOutputs);
- return getWritablesFromPyOutputs(pyOutputs);
+ if (returnAllVariables) {
+ if (setupAndRun){
+ return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
+ }
+ return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
+ }
+
+ if (outputDict != null) {
+ if (setupAndRun) {
+ PythonExecutioner.execWithSetupAndRun(this, pyInputs);
+ }else{
+ PythonExecutioner.exec(this, pyInputs);
+ }
+ PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
+ return getWritablesFromPyOutputs(out);
+ }
+ else {
+ if (setupAndRun) {
+ PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
+ }else{
+ PythonExecutioner.exec(code, pyInputs, outputs);
+ }
+
+ return getWritablesFromPyOutputs(outputs);
+ }
+
}
catch (Exception e){
throw new RuntimeException(e);
@@ -102,7 +201,7 @@ public class PythonTransform implements Transform{
@Override
public String[] outputColumnNames(){
- return pyOutputs.getVariables();
+ return outputs.getVariables();
}
@Override
@@ -111,7 +210,7 @@ public class PythonTransform implements Transform{
}
@Override
public String[] columnNames(){
- return pyOutputs.getVariables();
+ return outputs.getVariables();
}
@Override
@@ -124,14 +223,13 @@ public class PythonTransform implements Transform{
}
- private PythonVariables getPyInputsFromWritables(List writables){
-
+ private PythonVariables getPyInputsFromWritables(List writables) {
PythonVariables ret = new PythonVariables();
- for (String name: pyInputs.getVariables()){
+ for (String name: inputs.getVariables()) {
int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx);
- PythonVariables.Type pyType = pyInputs.getType(name);
+ PythonVariables.Type pyType = inputs.getType(name);
switch (pyType){
case INT:
if (w instanceof LongWritable){
@@ -143,7 +241,7 @@ public class PythonTransform implements Transform{
break;
case FLOAT:
- if (w instanceof DoubleWritable){
+ if (w instanceof DoubleWritable) {
ret.addFloat(name, ((DoubleWritable)w).get());
}
else{
@@ -151,96 +249,99 @@ public class PythonTransform implements Transform{
}
break;
case STR:
- ret.addStr(name, ((Text)w).toString());
+ ret.addStr(name, w.toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
+ default:
+ throw new RuntimeException("Unsupported input type:" + pyType);
}
}
return ret;
}
- private List getWritablesFromPyOutputs(PythonVariables pyOuts){
+ private List getWritablesFromPyOutputs(PythonVariables pyOuts) {
List out = new ArrayList<>();
- for (int i=0; i dictValue = pyOuts.getDictValue(name);
+ Map noNullValues = new java.util.HashMap<>();
+ for(Map.Entry entry : dictValue.entrySet()) {
+ if(entry.getValue() != org.json.JSONObject.NULL) {
+ noNullValues.put(entry.getKey(), entry.getValue());
+ }
+ }
+
+ try {
+ out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
+ } catch (JsonProcessingException e) {
+ throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
+ }
+ break;
+ case LIST:
+ Object[] listValue = pyOuts.getListValue(name);
+ try {
+ out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
+ } catch (JsonProcessingException e) {
+ throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
+ }
+ break;
+ default:
+ throw new IllegalStateException("Unable to support type " + pyType.name());
}
}
return out;
}
- public PythonTransform(String code) throws Exception{
- this.code = code;
- this.name = UUID.randomUUID().toString();
- }
- private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
- PythonVariables pyVars = new PythonVariables();
- int numCols = schema.numColumns();
- for (int i=0; i 0,"Input must have variables. Found none.");
+ for(Map.Entry entry : input.getVars().entrySet()) {
+ switch(entry.getValue()) {
+ case INT:
+ schemaBuilder.addColumnInteger(entry.getKey());
+ break;
+ case STR:
+ schemaBuilder.addColumnString(entry.getKey());
+ break;
+ case FLOAT:
+ schemaBuilder.addColumnFloat(entry.getKey());
+ break;
+ case NDARRAY:
+ schemaBuilder.addColumnNDArray(entry.getKey(),null);
+ break;
+ case BOOL:
+ schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
+ }
+ }
+
+ return schemaBuilder.build();
+ }
+
+ /**
+ * Create a {@link Schema} from an input
+ * {@link PythonVariables}
+ * Types are mapped to types of the same name
+ * @param input the input schema
+ * @return the output python variables.
+ */
+ public static PythonVariables fromSchema(Schema input) {
+ PythonVariables ret = new PythonVariables();
+ for(int i = 0; i < input.numColumns(); i++) {
+ String currColumnName = input.getName(i);
+ ColumnType columnType = input.getType(i);
+ switch(columnType) {
+ case NDArray:
+ ret.add(currColumnName, PythonVariables.Type.NDARRAY);
+ break;
+ case Boolean:
+ ret.add(currColumnName, PythonVariables.Type.BOOL);
+ break;
+ case Categorical:
+ case String:
+ ret.add(currColumnName, PythonVariables.Type.STR);
+ break;
+ case Double:
+ case Float:
+ ret.add(currColumnName, PythonVariables.Type.FLOAT);
+ break;
+ case Integer:
+ case Long:
+ ret.add(currColumnName, PythonVariables.Type.INT);
+ break;
+ case Bytes:
+ break;
+ case Time:
+ throw new UnsupportedOperationException("Unable to process dates with python yet.");
+ }
+ }
+
+ return ret;
+ }
+ /**
+ * Convert a {@link Schema}
+ * to {@link PythonVariables}
+ * @param schema the input schema
+ * @return the output {@link PythonVariables} where each
+ * name in the map is associated with a column name in the schema.
+ * A proper type is also chosen based on the schema
+ * @throws Exception
+ */
+ public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception {
+ PythonVariables pyVars = new PythonVariables();
+ int numCols = schema.numColumns();
+ for (int i = 0; i < numCols; i++) {
+ String colName = schema.getName(i);
+ ColumnType colType = schema.getType(i);
+ switch (colType){
+ case Long:
+ case Integer:
+ pyVars.addInt(colName);
+ break;
+ case Double:
+ case Float:
+ pyVars.addFloat(colName);
+ break;
+ case String:
+ pyVars.addStr(colName);
+ break;
+ case NDArray:
+ pyVars.addNDArray(colName);
+ break;
+ default:
+ throw new Exception("Unsupported python input type: " + colType.toString());
+ }
+ }
+
+ return pyVars;
+ }
+
+
+ public static NumpyArray mapToNumpyArray(Map map){
+ String dtypeName = (String)map.get("dtype");
+ DataType dtype;
+ if (dtypeName.equals("float64")){
+ dtype = DataType.DOUBLE;
+ }
+ else if (dtypeName.equals("float32")){
+ dtype = DataType.FLOAT;
+ }
+ else if (dtypeName.equals("int16")){
+ dtype = DataType.SHORT;
+ }
+ else if (dtypeName.equals("int32")){
+ dtype = DataType.INT;
+ }
+ else if (dtypeName.equals("int64")){
+ dtype = DataType.LONG;
+ }
+ else{
+ throw new RuntimeException("Unsupported array type " + dtypeName + ".");
+ }
+ List shapeList = (List)map.get("shape");
+ long[] shape = new long[shapeList.size()];
+ for (int i = 0; i < shape.length; i++) {
+ shape[i] = (Long)shapeList.get(i);
+ }
+
+ List strideList = (List)map.get("shape");
+ long[] stride = new long[strideList.size()];
+ for (int i = 0; i < stride.length; i++) {
+ stride[i] = (Long)strideList.get(i);
+ }
+ long address = (Long)map.get("address");
+ NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
+ return numpyArray;
+ }
+
+ public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){
+ Map dict = pyvars.getDictValue(key);
+ String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]);
+ PythonVariables pyvars2 = new PythonVariables();
+ for (String subkey: keys){
+ Object value = dict.get(subkey);
+ if (value instanceof Map){
+ Map map = (Map)value;
+ if (map.containsKey("_is_numpy_array")){
+ pyvars2.addNDArray(subkey, mapToNumpyArray(map));
+
+ }
+ else{
+ pyvars2.addDict(subkey, (Map)value);
+ }
+
+ }
+ else if (value instanceof List){
+ pyvars2.addList(subkey, ((List) value).toArray());
+ }
+ else if (value instanceof String){
+ System.out.println((String)value);
+ pyvars2.addStr(subkey, (String) value);
+ }
+ else if (value instanceof Integer || value instanceof Long) {
+ Number number = (Number) value;
+ pyvars2.addInt(subkey, number.intValue());
+ }
+ else if (value instanceof Float || value instanceof Double) {
+ Number number = (Number) value;
+ pyvars2.addFloat(subkey, number.doubleValue());
+ }
+ else if (value instanceof NumpyArray){
+ pyvars2.addNDArray(subkey, (NumpyArray)value);
+ }
+ else if (value == null){
+ pyvars2.addStr(subkey, "None"); // FixMe
+ }
+ else{
+ throw new RuntimeException("Unsupported type!" + value);
+ }
+ }
+ return pyvars2;
+ }
+
+ public static long[] jsonArrayToLongArray(JSONArray jsonArray){
+ long[] longs = new long[jsonArray.length()];
+ for (int i=0; i toMap(JSONObject jsonobj) {
+ Map map = new HashMap<>();
+ String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
+ for (String key: keys){
+ Object value = jsonobj.get(key);
+ if (value instanceof JSONArray) {
+ value = toList((JSONArray) value);
+ } else if (value instanceof JSONObject) {
+ JSONObject jsonobj2 = (JSONObject)value;
+ if (jsonobj2.has("_is_numpy_array")){
+ value = jsonToNumpyArray(jsonobj2);
+ }
+ else{
+ value = toMap(jsonobj2);
+ }
+
+ }
+
+ map.put(key, value);
+ } return map;
+ }
+
+
+ public static List
- com.typesafe.play
- play-java_2.11
- ${playframework.version}
-
-
- com.google.code.findbugs
- jsr305
-
-
- org.apache.tomcat
- tomcat-servlet-api
-
-
- net.jodah
- typetools
-
-
+ io.vertx
+ vertx-core
+ ${vertx.version}
+
- net.jodah
- typetools
- ${jodah.typetools.version}
+ io.vertx
+ vertx-web
+ ${vertx.version}
+
com.mashape.unirest
unirest-java
@@ -108,25 +92,16 @@
${project.version}
test
-
- com.typesafe.play
- play-json_2.11
- ${playframework.version}
-
-
- com.typesafe.play
- play-server_2.11
- ${playframework.version}
-
com.beust
jcommander
${jcommander.version}
+
- com.typesafe.play
- play-netty-server_2.11
- ${playframework.version}
+ ch.qos.logback
+ logback-classic
+ test
@@ -144,11 +119,11 @@
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
org.nd4j
- nd4j-cuda-10.1
+ nd4j-cuda-10.2
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
index a79b57b19..6610e75f9 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.vertx.core.AbstractVerticle;
+import io.vertx.core.Vertx;
+import io.vertx.ext.web.Router;
+import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
@@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
@@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
-import play.BuiltInComponents;
-import play.Mode;
-import play.libs.Json;
-import play.routing.Router;
-import play.routing.RoutingDsl;
-import play.server.Server;
import java.io.File;
import java.util.*;
-import static play.mvc.Controller.request;
-import static play.mvc.Results.*;
-
/**
* A rest server for using an
* {@link VPTree} based on loading an ndarray containing
@@ -57,22 +55,33 @@ import static play.mvc.Results.*;
* @author Adam Gibson
*/
@Slf4j
-public class NearestNeighborsServer {
- @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
- private String ndarrayPath = null;
- @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
- private String labelsPath = null;
- @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
- private int port = 9000;
- @Parameter(names = {"--similarityFunction"}, arity = 1)
- private String similarityFunction = "euclidean";
- @Parameter(names = {"--invert"}, arity = 1)
- private boolean invert = false;
+public class NearestNeighborsServer extends AbstractVerticle {
- private Server server;
+ private static class RunArgs {
+ @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
+ private String ndarrayPath = null;
+ @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
+ private String labelsPath = null;
+ @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
+ private int port = 9000;
+ @Parameter(names = {"--similarityFunction"}, arity = 1)
+ private String similarityFunction = "euclidean";
+ @Parameter(names = {"--invert"}, arity = 1)
+ private boolean invert = false;
+ }
- public void runMain(String... args) throws Exception {
- JCommander jcmdr = new JCommander(this);
+ private static RunArgs instanceArgs;
+ private static NearestNeighborsServer instance;
+
+ public NearestNeighborsServer(){ }
+
+ public static NearestNeighborsServer getInstance(){
+ return instance;
+ }
+
+ public static void runMain(String... args) {
+ RunArgs r = new RunArgs();
+ JCommander jcmdr = new JCommander(r);
try {
jcmdr.parse(args);
@@ -84,7 +93,7 @@ public class NearestNeighborsServer {
//User provides invalid input -> print the usage info
jcmdr.usage();
- if (ndarrayPath == null)
+ if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)");
try {
Thread.sleep(500);
@@ -93,16 +102,20 @@ public class NearestNeighborsServer {
System.exit(1);
}
+ instanceArgs = r;
try {
- runHelper();
+ Vertx vertx = Vertx.vertx();
+ vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t);
}
}
- protected void runHelper() throws Exception {
+ @Override
+ public void start() throws Exception {
+ instance = this;
- String[] pathArr = ndarrayPath.split(",");
+ String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length];
// first of all we reading shapes of saved eariler files
int rows = 0;
@@ -111,7 +124,7 @@ public class NearestNeighborsServer {
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
- Shape.size(shape, 1));
+ Shape.size(shape, 1));
if (Shape.rank(shape) != 2)
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
@@ -122,12 +135,12 @@ public class NearestNeighborsServer {
cols = Shape.size(shape, 1);
else if (cols != Shape.size(shape, 1))
throw new DL4JInvalidInputException(
- "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
+ "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
}
final List labels = new ArrayList<>();
- if (labelsPath != null) {
- String[] labelsPathArr = labelsPath.split(",");
+ if (instanceArgs.labelsPath != null) {
+ String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
}
@@ -149,7 +162,7 @@ public class NearestNeighborsServer {
System.gc();
}
- VPTree tree = new VPTree(points, similarityFunction, invert);
+ VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
@@ -163,40 +176,57 @@ public class NearestNeighborsServer {
System.setProperty("play.crypto.secret", base64);
}
+ Router r = Router.router(vertx);
+ r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
+ createRoutes(r, labels, tree, points);
- server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
+ vertx.createHttpServer()
+ .requestHandler(r)
+ .listen(instanceArgs.port);
}
- protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){
- RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
- //return the host information for a given id
- routingDsl.POST("/knn").routingTo(request -> {
+ private void createRoutes(Router r, List labels, VPTree tree, INDArray points){
+
+ r.post("/knn").handler(rc -> {
try {
- NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
+ String json = rc.getBodyAsJson().encode();
+ NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
+
NearestNeighbor nearestNeighbor =
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
- if (record == null)
- return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
+ if (record == null) {
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
+ return;
+ }
- NearestNeighborsResults results =
- NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
-
-
- return ok(Json.toJson(results));
+ NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(results));
+ return;
} catch (Throwable e) {
log.error("Error in POST /knn",e);
e.printStackTrace();
- return internalServerError(e.getMessage());
+ rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end("Error parsing request - " + e.getMessage());
+ return;
}
});
- routingDsl.POST("/knnnew").routingTo(request -> {
+ r.post("/knnnew").handler(rc -> {
try {
- Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
- if (record == null)
- return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
+ String json = rc.getBodyAsJson().encode();
+ Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
+ if (record == null) {
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
+ return;
+ }
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List results;
@@ -214,9 +244,10 @@ public class NearestNeighborsServer {
}
if (results.size() != distances.size()) {
- return internalServerError(
- String.format("results.size == %d != %d == distances.size",
- results.size(), distances.size()));
+ rc.response()
+ .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
+ return;
}
List nnResult = new ArrayList<>();
@@ -228,30 +259,29 @@ public class NearestNeighborsServer {
}
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
- return ok(Json.toJson(results2));
-
+ String j = JsonMappers.getMapper().writeValueAsString(results2);
+ rc.response()
+ .putHeader("content-type", "application/json")
+ .end(j);
} catch (Throwable e) {
log.error("Error in POST /knnnew",e);
e.printStackTrace();
- return internalServerError(e.getMessage());
+ rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end("Error parsing request - " + e.getMessage());
+ return;
}
});
-
- return routingDsl.build();
}
/**
* Stop the server
*/
- public void stop() {
- if (server != null) {
- log.info("Attempting to stop server");
- server.stop();
- }
+ public void stop() throws Exception {
+ super.stop();
}
public static void main(String[] args) throws Exception {
- new NearestNeighborsServer().runMain(args);
+ runMain(args);
}
}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
index 9f8fd7241..b42c407e5 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
public TemporaryFolder testDir = new TemporaryFolder();
@Test
- //@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data);
@@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp);
- NearestNeighborsServer server = new NearestNeighborsServer();
- server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
+ NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort));
+ Thread.sleep(3000);
+
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size());
- server.stop();
+ NearestNeighborsServer.getInstance().stop();
}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
new file mode 100644
index 000000000..7953c2712
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
@@ -0,0 +1,42 @@
+
+
+
+
+
+ logs/application.log
+
+ %date - [%level] - from %logger in %thread
+ %n%message%n%xException%n
+
+
+
+
+
+ %logger{15} - %message%n%xException{5}
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
index d6b64b025..e3ca20366 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
@@ -54,7 +54,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
index 609d48a39..bfd004c41 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
@@ -53,7 +53,7 @@
test-nd4j-native
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
index f95f9268d..87bb7e68e 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
@@ -83,11 +83,11 @@
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
org.nd4j
- nd4j-cuda-10.1
+ nd4j-cuda-10.2
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
index d820dd6b7..23d5d225d 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
@@ -1,5 +1,6 @@
- test-nd4j-cuda-10.1
+ test-nd4j-cuda-10.2
false
@@ -514,7 +514,7 @@
org.nd4j
- nd4j-cuda-10.1
+ nd4j-cuda-10.2
${nd4j.version}
test
diff --git a/docs/deeplearning4j/templates/config-cudnn.md b/docs/deeplearning4j/templates/config-cudnn.md
index 5044b3ca0..24f69da87 100644
--- a/docs/deeplearning4j/templates/config-cudnn.md
+++ b/docs/deeplearning4j/templates/config-cudnn.md
@@ -10,17 +10,8 @@ weight: 3
Deeplearning4j supports CUDA but can be further accelerated with cuDNN. Most 2D CNN layers (such as ConvolutionLayer, SubsamplingLayer, etc), and also LSTM and BatchNormalization layers support CuDNN.
-The only thing we need to do to have DL4J load cuDNN is to add a dependency on `deeplearning4j-cuda-9.2`, `deeplearning4j-cuda-10.0`, or `deeplearning4j-cuda-10.1`, for example:
+The only thing we need to do to have DL4J load cuDNN is to add a dependency on `deeplearning4j-cuda-10.0`, `deeplearning4j-cuda-10.1`, or `deeplearning4j-cuda-10.2` for example:
-```xml
-
- org.deeplearning4j
- deeplearning4j-cuda-9.2
- {{page.version}}
-
-```
-
-or
```xml
org.deeplearning4j
@@ -38,6 +29,16 @@ or
```
+or
+```xml
+
+ org.deeplearning4j
+ deeplearning4j-cuda-10.2
+ {{page.version}}
+
+```
+
+
The actual library for cuDNN is not bundled, so be sure to download and install the appropriate package for your platform from NVIDIA:
* [NVIDIA cuDNN](https://developer.nvidia.com/cudnn)
@@ -48,39 +49,20 @@ Note there are multiple combinations of cuDNN and CUDA supported. At this time t
CUDA Version |
cuDNN Version |
- 9.2 | 7.2 |
10.0 | 7.4 |
10.1 | 7.6 |
+ 10.2 | 7.6 |
- To install, simply extract the library to a directory found in the system path used by native libraries. The easiest way is to place it alongside other libraries from CUDA in the default directory (`/usr/local/cuda/lib64/` on Linux, `/usr/local/cuda/lib/` on Mac OS X, and `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2\bin\`, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\`, or `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin\` on Windows).
+ To install, simply extract the library to a directory found in the system path used by native libraries. The easiest way is to place it alongside other libraries from CUDA in the default directory (`/usr/local/cuda/lib64/` on Linux, `/usr/local/cuda/lib/` on Mac OS X, and `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\`, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin\`, or `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin\` on Windows).
-Alternatively, in the case of CUDA 10.1, cuDNN comes bundled with the "redist" package of the [JavaCPP Presets for CUDA](https://github.com/bytedeco/javacpp-presets/tree/master/cuda). [After agreeing to the license](https://github.com/bytedeco/javacpp-presets/tree/master/cuda#license-agreements), we can add the following dependencies instead of installing CUDA and cuDNN:
+Alternatively, in the case of CUDA 10.2, cuDNN comes bundled with the "redist" package of the [JavaCPP Presets for CUDA](https://github.com/bytedeco/javacpp-presets/tree/master/cuda). [After agreeing to the license](https://github.com/bytedeco/javacpp-presets/tree/master/cuda#license-agreements), we can add the following dependencies instead of installing CUDA and cuDNN:
org.bytedeco
- cuda
- 10.1-7.6-1.5.2
- linux-x86_64-redist
-
-
- org.bytedeco
- cuda
- 10.1-7.6-1.5.2
- linux-ppc64le-redist
-
-
- org.bytedeco
- cuda
- 10.1-7.6-1.5.2
- macosx-x86_64-redist
-
-
- org.bytedeco
- cuda
- 10.1-7.6-1.5.2
- windows-x86_64-redist
+ cuda-platform-redist
+ 10.2-7.6-1.5.2
Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example:
diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt
index 50c6b9b8a..d8b0439b4 100755
--- a/libnd4j/CMakeLists.txt
+++ b/libnd4j/CMakeLists.txt
@@ -25,8 +25,8 @@ elseif (APPLE)
elseif(WIN32)
set(X86_BUILD true)
if (CUDA_BLAS)
- set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
- set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
+ set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true")
+ set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC -std=c++11 -fmax-errors=2")
diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt
index c804ce5ec..c86bdc13a 100755
--- a/libnd4j/blas/CMakeLists.txt
+++ b/libnd4j/blas/CMakeLists.txt
@@ -111,7 +111,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
# using Visual Studio C++
- set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc ${ARCH_TUNE}")
+ set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
# using GCC
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
@@ -158,7 +158,7 @@ if(CUDA_BLAS)
include_directories(${CUDA_INCLUDE_DIRS})
message("CUDA found!")
set( CUDA_ARCHITECTURE_MINIMUM "3.0" CACHE STRING "Minimum required CUDA compute capability" )
- SET(CUDA_VERBOSE_BUILD ON)
+ SET(CUDA_VERBOSE_BUILD OFF)
SET(CUDA_SEPARABLE_COMPILATION OFF)
#set(CUDA_COMPUTE_CAPABILITY "61")
set(CUDA_COMPUTE_CAPABILITY "35")
@@ -175,9 +175,9 @@ if(CUDA_BLAS)
if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
@@ -185,9 +185,9 @@ if(CUDA_BLAS)
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all")
if (APPLE)
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60)
else()
- list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
+ list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
endif()
else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
@@ -264,24 +264,13 @@ if(CUDA_BLAS)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
- if (NOT BUILD_TESTS)
- CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
+
+ CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
- else()
- set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
-
- CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
- ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
- ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
- cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
- Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
- ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
- endif()
-
if(WIN32)
message("CUDA on Windows: enabling /EHsc")
@@ -289,11 +278,16 @@ if(CUDA_BLAS)
SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14")
endif()
-
target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
install(TARGETS ${LIBND4J_NAME} DESTINATION .)
+
+ add_custom_command(
+ TARGET ${LIBND4J_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy
+ $
+ ${PROJECT_BINARY_DIR}/../../tests_cpu/)
endif(CUDA_FOUND)
elseif(CPU_BLAS)
diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp
index 00a984d45..df358b64f 100644
--- a/libnd4j/blas/NDArray.hpp
+++ b/libnd4j/blas/NDArray.hpp
@@ -31,9 +31,9 @@
namespace nd4j {
template <>
-utf8string NDArray::e(const Nd4jLong i) const;
+ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const;
template <>
-std::string NDArray::e(const Nd4jLong i) const;
+ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const;
//////////////////////////////////////////////////////////////////////////
template
@@ -48,7 +48,7 @@ NDArray* NDArray::asT() const{
return result;
}
-BUILD_SINGLE_TEMPLATE(template NDArray* NDArray::asT, () const, LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArray::asT, () const, LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
// copy constructor
@@ -435,7 +435,7 @@ std::vector NDArray::getBufferAsVector() {
vector[e] = this->e(e);
return vector;
}
-BUILD_SINGLE_TEMPLATE(template std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
std::vector NDArray::getShapeAsFlatVector() {
@@ -813,7 +813,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *va
auto xOffset = shape::getOffset(getShapeInfo(), indices);
t[xOffset] = static_cast(y);
}
-BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES);
+BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
template
@@ -823,7 +823,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *valu
t[offset] = static_cast(y);
}
-BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES);
+BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void NDArray::setContext(nd4j::LaunchContext *context) {
@@ -1301,7 +1301,7 @@ template
void* NDArray::templatedPointerShift(const Nd4jLong offset) const {
return reinterpret_cast(getBuffer()) + offset;
}
-BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected
@@ -1608,7 +1608,7 @@ bool NDArray::isUnitary() {
//////////////////////////////////////////////////////////////////////////
template <>
-std::string* NDArray::bufferAsT() const {
+std::string* ND4J_EXPORT NDArray::bufferAsT() const {
throw std::runtime_error("This method is NOT supposed to be used");
}
@@ -1620,7 +1620,7 @@ T* NDArray::bufferAsT() const {
return reinterpret_cast(getBuffer());
}
-BUILD_SINGLE_UNCHAINED_TEMPLATE(template, * NDArray::bufferAsT() const, LIBND4J_TYPES);
+BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
NDArray* NDArray::subarray(IndicesList& idx) const {
@@ -1797,16 +1797,16 @@ NDArray NDArray::operator+(const T& scalar) const {
return result;
}
-template NDArray NDArray::operator+(const double& scalar) const;
-template NDArray NDArray::operator+(const float& scalar) const;
-template NDArray NDArray::operator+(const float16& scalar) const;
-template NDArray NDArray::operator+(const bfloat16& scalar) const;
-template NDArray NDArray::operator+(const Nd4jLong& scalar) const;
-template NDArray NDArray::operator+(const int& scalar) const;
-template NDArray NDArray::operator+(const int16_t& scalar) const;
-template NDArray NDArray::operator+(const int8_t& scalar) const;
-template NDArray NDArray::operator+(const uint8_t& scalar) const;
-template NDArray NDArray::operator+(const bool& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const double& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const float& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const float16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const bfloat16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const Nd4jLong& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const int& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const int16_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const int8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const uint8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator+(const bool& scalar) const;
////////////////////////////////////////////////////////////////////////
// subtraction operator array - scalar
@@ -1824,16 +1824,16 @@ NDArray NDArray::operator-(const T& scalar) const {
return result;
}
-template NDArray NDArray::operator-(const double& scalar) const;
-template NDArray NDArray::operator-(const float& scalar) const;
-template NDArray NDArray::operator-(const float16& scalar) const;
-template NDArray NDArray::operator-(const bfloat16& scalar) const;
-template NDArray NDArray::operator-(const Nd4jLong& scalar) const;
-template NDArray NDArray::operator-(const int& scalar) const;
-template NDArray NDArray::operator-(const int16_t& scalar) const;
-template NDArray NDArray::operator-(const int8_t& scalar) const;
-template NDArray NDArray::operator-(const uint8_t& scalar) const;
-template NDArray NDArray::operator-(const bool& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const double& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const float& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const float16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const bfloat16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const Nd4jLong& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const int& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const int16_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const int8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const uint8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator-(const bool& scalar) const;
////////////////////////////////////////////////////////////////////////
// multiplication operator array*scalar
@@ -1851,16 +1851,16 @@ NDArray NDArray::operator*(const T& scalar) const {
return result;
}
-template NDArray NDArray::operator*(const double& scalar) const;
-template NDArray NDArray::operator*(const float& scalar) const;
-template NDArray NDArray::operator*(const float16& scalar) const;
-template NDArray NDArray::operator*(const bfloat16& scalar) const;
-template NDArray NDArray::operator*(const Nd4jLong& scalar) const;
-template NDArray NDArray::operator*(const int& scalar) const;
-template NDArray NDArray::operator*(const int16_t& scalar) const;
-template NDArray NDArray::operator*(const int8_t& scalar) const;
-template NDArray NDArray::operator*(const uint8_t& scalar) const;
-template NDArray NDArray::operator*(const bool& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const double& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const float& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const float16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const bfloat16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const Nd4jLong& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const int& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const int16_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const int8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const uint8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator*(const bool& scalar) const;
////////////////////////////////////////////////////////////////////////
// division operator array / scalar
@@ -1881,16 +1881,16 @@ NDArray NDArray::operator/(const T& scalar) const {
return result;
}
-template NDArray NDArray::operator/(const double& scalar) const;
-template NDArray NDArray::operator/(const float& scalar) const;
-template NDArray NDArray::operator/(const float16& scalar) const;
-template NDArray NDArray::operator/(const bfloat16& scalar) const;
-template NDArray NDArray::operator/(const Nd4jLong& scalar) const;
-template NDArray NDArray::operator/(const int& scalar) const;
-template NDArray NDArray::operator/(const int16_t& scalar) const;
-template NDArray NDArray::operator/(const int8_t& scalar) const;
-template NDArray NDArray::operator/(const uint8_t& scalar) const;
-template NDArray NDArray::operator/(const bool& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const double& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const float& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const float16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const bfloat16& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const Nd4jLong& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const int& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const int16_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const int8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const uint8_t& scalar) const;
+template ND4J_EXPORT NDArray NDArray::operator/(const bool& scalar) const;
////////////////////////////////////////////////////////////////////////
// addition operator scalar + array
@@ -2260,13 +2260,13 @@ void NDArray::operator+=(const T value) {
NDArray::registerSpecialUse({this}, {});
}
-template void NDArray::operator+=(const double value);
-template void NDArray::operator+=(const float value);
-template void NDArray::operator+=(const float16 value);
-template void NDArray::operator+=(const bfloat16 value);
-template void NDArray::operator+=(const Nd4jLong value);
-template void NDArray::operator+=(const int value);
-template void NDArray::operator+=(const bool value);
+template ND4J_EXPORT void NDArray::operator+=(const double value);
+template ND4J_EXPORT void NDArray::operator+=(const float value);
+template ND4J_EXPORT void NDArray::operator+=(const float16 value);
+template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value);
+template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value);
+template ND4J_EXPORT void NDArray::operator+=(const int value);
+template ND4J_EXPORT void NDArray::operator+=(const bool value);
////////////////////////////////////////////////////////////////////////
template
@@ -2282,13 +2282,13 @@ void NDArray::operator-=(const T value) {
NDArray::registerSpecialUse({this}, {});
}
-template void NDArray::operator-=(const double value);
-template void NDArray::operator-=(const float value);
-template void NDArray::operator-=(const float16 value);
-template void NDArray::operator-=(const bfloat16 value);
-template void NDArray::operator-=(const Nd4jLong value);
-template void NDArray::operator-=(const int value);
-template void NDArray::operator-=(const bool value);
+template ND4J_EXPORT void NDArray::operator-=(const double value);
+template ND4J_EXPORT void NDArray::operator-=(const float value);
+template ND4J_EXPORT void NDArray::operator-=(const float16 value);
+template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value);
+template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value);
+template ND4J_EXPORT void NDArray::operator-=(const int value);
+template ND4J_EXPORT void NDArray::operator-=(const bool value);
////////////////////////////////////////////////////////////////////////
template
@@ -2302,16 +2302,16 @@ void NDArray::operator*=(const T scalar) {
NDArray::registerSpecialUse({this}, {});
}
-template void NDArray::operator*=(const double scalar);
-template void NDArray::operator*=(const float scalar);
-template void NDArray::operator*=(const float16 scalar);
-template void NDArray::operator*=(const bfloat16 scalar);
-template void NDArray::operator*=(const Nd4jLong scalar);
-template void NDArray::operator*=(const int scalar);
-template void NDArray::operator*=(const int16_t scalar);
-template void NDArray::operator*=(const int8_t scalar);
-template void NDArray::operator*=(const uint8_t scalar);
-template void NDArray::operator*=(const bool scalar);
+template ND4J_EXPORT void NDArray::operator*=(const double scalar);
+template ND4J_EXPORT void NDArray::operator*=(const float scalar);
+template ND4J_EXPORT void NDArray::operator*=(const float16 scalar);
+template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar);
+template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar);
+template ND4J_EXPORT void NDArray::operator*=(const int scalar);
+template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar);
+template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar);
+template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar);
+template ND4J_EXPORT void NDArray::operator*=(const bool scalar);
////////////////////////////////////////////////////////////////////////
template
@@ -2324,16 +2324,16 @@ void NDArray::operator/=(const T scalar) {
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {});
}
-template void NDArray::operator/=(const double scalar);
-template void NDArray::operator/=(const float scalar);
-template void NDArray::operator/=(const float16 scalar);
-template void NDArray::operator/=(const bfloat16 scalar);
-template void NDArray::operator/=(const Nd4jLong scalar);
-template void NDArray::operator/=(const int scalar);
-template void NDArray::operator/=(const int16_t scalar);
-template void NDArray::operator/=(const int8_t scalar);
-template void NDArray::operator/=(const uint8_t scalar);
-template void NDArray::operator/=(const bool scalar);
+template ND4J_EXPORT void NDArray::operator/=(const double scalar);
+template ND4J_EXPORT void NDArray::operator/=(const float scalar);
+template ND4J_EXPORT void NDArray::operator/=(const float16 scalar);
+template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar);
+template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar);
+template ND4J_EXPORT void NDArray::operator/=(const int scalar);
+template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar);
+template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar);
+template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar);
+template ND4J_EXPORT void NDArray::operator/=(const bool scalar);
////////////////////////////////////////////////////////////////////////
// subtraction operator array - array
@@ -2929,7 +2929,7 @@ std::vector NDArray::asVectorT() {
return result;
}
-BUILD_SINGLE_TEMPLATE(template std::vector, NDArray::asVectorT(), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
@@ -3046,7 +3046,7 @@ template
void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value) {
BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES);
}
-BUILD_SINGLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* other, NDArray *target, ExtraArguments *extraParams) const{
@@ -3109,7 +3109,7 @@ void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const
const auto y = reinterpret_cast(yBuffer);
x[xOffset] = static_cast(y[yOffset]);
}
-BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES);
+BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::vector& dimensions) const {
@@ -3356,7 +3356,7 @@ T NDArray::e(const Nd4jLong i) const {
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
}
-BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES);
+BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
// Returns value from 2D matrix by coordinates/indexes
@@ -3376,7 +3376,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
return static_cast(119);
}
-BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
+BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
// returns value from 3D tensor by coordinates
@@ -3396,7 +3396,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
return static_cast(119);
}
-BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
+BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
// returns value from 3D tensor by coordinates
@@ -3416,7 +3416,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
return static_cast(119);
}
-BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
+BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::e(const Nd4jLong i) const {
@@ -3591,17 +3591,17 @@ void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray *target,
applyScalarArr(op, &scalarArr, target, extraParams);
}
-template <> void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams);
-template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams);
+template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams);
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams);
//////////////////////////////////////////////////////////////////////////
void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const {
@@ -3627,17 +3627,17 @@ void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray *tar
applyScalarArr(op, &scalarArr, target, extraParams);
}
-template <> void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
-template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
+template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
//////////////////////////////////////////////////////////////////////////
@@ -3665,17 +3665,17 @@ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool sc
applyScalarArr(op, &scalarArr, target, extraParams);
}
- template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
- template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
+ template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
////////////////////////////////////////////////////////////////////////
@@ -3966,19 +3966,19 @@ void NDArray::p(const Nd4jLong i, const T value) {
NDArray::registerPrimaryUse({this}, {});
}
-template void NDArray::p(const Nd4jLong i, const double value);
-template void NDArray::p(const Nd4jLong i, const float value);
-template void NDArray::p(const Nd4jLong i, const float16 value);
-template void NDArray::p(const Nd4jLong i, const bfloat16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong value);
-template void NDArray::p(const Nd4jLong i, const int value);
-template void NDArray::p(const Nd4jLong i, const int8_t value);
-template void NDArray::p(const Nd4jLong i, const uint8_t value);
-template void NDArray::p(const Nd4jLong i, const uint16_t value);
-template void NDArray::p(const Nd4jLong i, const uint32_t value);
-template void NDArray::p(const Nd4jLong i, const uint64_t value);
-template void NDArray::p(const Nd4jLong i, const int16_t value);
-template void NDArray::p(const Nd4jLong i, const bool value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value);
//////////////////////////////////////////////////////////////////////////
// This method sets value in 2D matrix to position i, j
@@ -3996,19 +3996,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
NDArray::registerPrimaryUse({this}, {});
}
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value);
//////////////////////////////////////////////////////////////////////////
// This method sets value in 3D matrix to position i,j,k
@@ -4026,19 +4026,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
NDArray::registerPrimaryUse({this}, {});
}
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value);
//////////////////////////////////////////////////////////////////////////
template
@@ -4055,19 +4055,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
NDArray::registerPrimaryUse({this}, {});
}
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value);
-template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value);
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value);
////////////////////////////////////////////////////////////////////////
void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
@@ -4256,7 +4256,7 @@ void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuff
if (xBuffer != nullptr && yBuffer != nullptr)
*(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset);
}
-BUILD_SINGLE_TEMPLATE(template void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/blas/cpu/NDArrayFactory.cpp b/libnd4j/blas/cpu/NDArrayFactory.cpp
index b091f13b7..54cc6bba8 100644
--- a/libnd4j/blas/cpu/NDArrayFactory.cpp
+++ b/libnd4j/blas/cpu/NDArrayFactory.cpp
@@ -29,7 +29,7 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////
template <>
- NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context) {
+ ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
@@ -71,8 +71,19 @@ namespace nd4j {
NDArray result(buffer, descriptor, context);
return result;
-
}
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context);
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
std::string s(str);
@@ -118,7 +129,7 @@ template
NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, nd4j::LaunchContext * context) {
return create_(order, shape, DataTypeUtils::fromT(), context);
}
-BUILD_SINGLE_TEMPLATE(template NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
template
@@ -128,20 +139,20 @@ void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) {
}
template <>
-void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) {
+ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) {
auto p = reinterpret_cast(ptr);
for (Nd4jLong e = 0; e < vector.size(); e++)
p[e] = vector[e];
}
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
-template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
+template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector);
#ifndef __JAVACPP_HACK__
@@ -150,16 +161,16 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector& shape, const T value, const char order, nd4j::LaunchContext * context) {
return valueOf(std::vector(shape), value, order);
}
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, nd4j::LaunchContext * context);
- template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, nd4j::LaunchContext * context);
////////////////////////////////////////////////////////////////////////
template
@@ -167,18 +178,18 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector vec(data);
return create(order, shape, vec, context);
}
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
- template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
+ template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context);
#endif
@@ -197,19 +208,19 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector
NDArray NDArrayFactory::create(nd4j::DataType type, const T scalar, nd4j::LaunchContext * context) {
@@ -223,20 +234,20 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector
NDArray NDArrayFactory::create(const T scalar, nd4j::LaunchContext * context) {
@@ -252,19 +263,19 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &
return new NDArray(NDArrayFactory::create(order, shape, data, context));
}
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
-template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
+template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context);
+template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector