commit 3275fe35a3
@@ -91,7 +91,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -70,7 +70,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -305,7 +305,7 @@ public class TestGraphLocalExecution {
     @Test
     public void testLocalExecutionEarlyStopping() throws Exception {
         EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
-                .epochTerminationConditions(new MaxEpochsTerminationCondition(6))
+                .epochTerminationConditions(new MaxEpochsTerminationCondition(4))
                 .scoreCalculator(new ScoreProvider())
                 .modelSaver(new InMemoryModelSaver()).build();
         Map<String, Object> commands = new HashMap<>();
@@ -348,7 +348,7 @@ public class TestGraphLocalExecution {
                 .dataProvider(dataProvider)
                 .scoreFunction(ScoreFunctions.testSetF1())
                 .modelSaver(new FileModelSaver(modelSavePath))
-                .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS),
+                .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
                         new MaxCandidatesCondition(10))
                 .build();
@@ -32,7 +32,7 @@ public class TestDataFactoryProviderMnist implements DataSetIteratorFactory {
     private int terminationIter;

     public TestDataFactoryProviderMnist(){
-        this(16, 10);
+        this(16, 4);
     }

     @Override
@@ -56,7 +56,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -37,7 +37,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
@@ -151,7 +151,7 @@
                     <skip>${skipTestResourceEnforcement}</skip>
                     <rules>
                         <requireActiveProfile>
-                            <profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles>
+                            <profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
                             <all>false</all>
                         </requireActiveProfile>
                     </rules>
@@ -333,11 +333,11 @@
         </profile>

         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
             <dependencies>
                 <dependency>
                     <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                     <version>${nd4j.version}</version>
                     <scope>test</scope>
                 </dependency>
@@ -20,7 +20,7 @@

 set -e

-VALID_VERSIONS=( 9.2 10.0 10.1 )
+VALID_VERSIONS=( 9.2 10.0 10.1 10.2 )

 usage() {
   echo "Usage: $(basename $0) [-h|--help] <cuda version to be used>
@@ -47,6 +47,10 @@ check_cuda_version() {
 check_cuda_version "$VERSION"

 case $VERSION in
+    10.2)
+        VERSION2="7.6"
+        VERSION3="1.5.2"
+        ;;
     10.1)
         VERSION2="7.6"
         VERSION3="1.5.2"
@@ -117,7 +117,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -56,7 +56,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -110,7 +110,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -72,7 +72,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -59,7 +59,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -126,7 +126,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -67,7 +67,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -58,7 +58,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
@@ -58,7 +58,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -49,7 +49,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -67,7 +67,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -65,7 +65,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -88,7 +88,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -256,11 +256,9 @@ public class ExecutionTest {

         TransformProcess transformProcess = new TransformProcess.Builder(schema)
                 .transform(
-                        new PythonTransform(
-                                "first = np.sin(first)\nsecond = np.cos(second)",
-                                schema
-                        )
-                )
+                        PythonTransform.builder().code(
+                                "first = np.sin(first)\nsecond = np.cos(second)")
+                                .outputSchema(schema).build())
                 .build();

         List<List<Writable>> functions = new ArrayList<>();
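Note: the change above replaces PythonTransform's positional constructor with a Lombok-style builder. For reference, the builder form reads like this (a minimal sketch using only calls visible in this diff; `schema` is assumed to be a Schema defined elsewhere in the test):

    // Builder-style construction introduced by this commit (sketch).
    // The python snippet runs row-wise; `schema` describes the output columns.
    PythonTransform transform = PythonTransform.builder()
            .code("first = np.sin(first)\nsecond = np.cos(second)")
            .outputSchema(schema)
            .build();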
@@ -14,35 +14,40 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.datavec.python;
+package org.datavec.local.transforms.transform;

 import org.datavec.api.transform.TransformProcess;
 import org.datavec.api.transform.condition.Condition;
 import org.datavec.api.transform.filter.ConditionFilter;
 import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.writable.*;
 import org.datavec.api.transform.schema.Schema;
-import org.junit.Ignore;
+import org.datavec.local.transforms.LocalTransformExecutor;

+import org.datavec.api.writable.*;
+import org.datavec.python.PythonCondition;
+import org.datavec.python.PythonTransform;
 import org.junit.Test;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

+import javax.annotation.concurrent.NotThreadSafe;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;

-@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
+import static junit.framework.TestCase.assertTrue;
+import static org.datavec.api.transform.schema.Schema.Builder;
+import static org.junit.Assert.*;

+@NotThreadSafe
 public class TestPythonTransformProcess {

-    @Test(timeout = 60000L)
+    @Test()
     public void testStringConcat() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnString("col1")
                 .addColumnString("col2");
@@ -54,10 +59,12 @@ public class TestPythonTransformProcess {
         String pythonCode = "col3 = col1 + col2";

         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).build();

-        List<Writable> inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!"));
+        List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));

         List<Writable> outputs = tp.execute(inputs);
         assertEquals((outputs.get(0)).toString(), "Hello ");
@@ -68,7 +75,7 @@ public class TestPythonTransformProcess {

     @Test(timeout = 60000L)
     public void testMixedTypes() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnInteger("col1")
                 .addColumnFloat("col2")
@@ -83,11 +90,12 @@ public class TestPythonTransformProcess {
         String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";

         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .inputSchema(initialSchema)
+                        .build() ).build();

-        List<Writable> inputs = Arrays.asList((Writable)
-                new IntWritable(10),
+        List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
                 new FloatWritable(3.5f),
                 new Text("5"),
                 new DoubleWritable(2.0)
@@ -105,7 +113,7 @@ public class TestPythonTransformProcess {

         INDArray expectedOutput = arr1.add(arr2);

-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -116,12 +124,14 @@ public class TestPythonTransformProcess {

         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build() ).build();

         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
-                new NDArrayWritable(arr2)
+                (Writable)
+                        new NDArrayWritable(arr1),
+                new NDArrayWritable(arr2)
         );

         List<Writable> outputs = tp.execute(inputs);
@@ -139,7 +149,7 @@ public class TestPythonTransformProcess {

         INDArray expectedOutput = arr1.add(arr2);

-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -150,11 +160,13 @@ public class TestPythonTransformProcess {

         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build() ).build();

         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
+                (Writable)
+                        new NDArrayWritable(arr1),
                 new NDArrayWritable(arr2)
         );
@@ -172,7 +184,7 @@ public class TestPythonTransformProcess {
         INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
         INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));

-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -183,11 +195,14 @@ public class TestPythonTransformProcess {

         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).build();

         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
+                (Writable)
+                        new NDArrayWritable(arr1),
                 new NDArrayWritable(arr2)
         );
@@ -199,8 +214,8 @@ public class TestPythonTransformProcess {
     }

     @Test(timeout = 60000L)
-    public void testPythonFilter(){
-        Schema schema = new Schema.Builder().addColumnInteger("column").build();
+    public void testPythonFilter() {
+        Schema schema = new Builder().addColumnInteger("column").build();

         Condition condition = new PythonCondition(
                 "f = lambda: column < 0"
@@ -210,17 +225,17 @@ public class TestPythonTransformProcess {

         Filter filter = new ConditionFilter(condition);

-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10))));
-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1))));
-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0))));
-        assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1))));
-        assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
+        assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
+        assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));

     }

     @Test(timeout = 60000L)
     public void testPythonFilterAndTransform() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnInteger("col1")
                 .addColumnFloat("col2")
@@ -241,33 +256,85 @@ public class TestPythonTransformProcess {

         String pythonCode = "col6 = str(col1 + col2)";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(
-                        pythonCode,
-                        finalSchema
-                )
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).filter(
                 filter
         ).build();

         List<List<Writable>> inputs = new ArrayList<>();
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(5),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(5),
                         new FloatWritable(3.0f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(-3),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(-3),
                         new FloatWritable(3.0f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(5),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(5),
                         new FloatWritable(11.2f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );

+        LocalTransformExecutor.execute(inputs,tp);
     }
-}
+
+    @Test
+    public void testPythonTransformNoOutputSpecified() throws Exception {
+        PythonTransform pythonTransform = PythonTransform.builder()
+                .code("a += 2; b = 'hello world'")
+                .returnAllInputs(true)
+                .build();
+        List<List<Writable>> inputs = new ArrayList<>();
+        inputs.add(Arrays.asList((Writable)new IntWritable(1)));
+        Schema inputSchema = new Builder()
+                .addColumnInteger("a")
+                .build();
+
+        TransformProcess tp = new TransformProcess.Builder(inputSchema)
+                .transform(pythonTransform)
+                .build();
+        List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
+        assertEquals(3,execute.get(0).get(0).toInt());
+        assertEquals("hello world",execute.get(0).get(1).toString());
+
+    }
+
+    @Test
+    public void testNumpyTransform() throws Exception {
+        PythonTransform pythonTransform = PythonTransform.builder()
+                .code("a += 2; b = 'hello world'")
+                .returnAllInputs(true)
+                .build();
+
+        List<List<Writable>> inputs = new ArrayList<>();
+        inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
+        Schema inputSchema = new Builder()
+                .addColumnNDArray("a",new long[]{1,1})
+                .build();
+
+        TransformProcess tp = new TransformProcess.Builder(inputSchema)
+                .transform(pythonTransform)
+                .build();
+        List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
+        assertFalse(execute.isEmpty());
+        assertNotNull(execute.get(0));
+        assertNotNull(execute.get(0).get(0));
+        assertEquals("hello world",execute.get(0).get(0).toString());
+    }
+
+}
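Note: the two new tests above exercise returnAllInputs(true), which skips a declared output schema entirely; the transform infers its outputs from whatever variables the Python snippet leaves behind. A minimal sketch of that pattern, condensed from testPythonTransformNoOutputSpecified above:

    // Sketch: no output schema is declared; the outputs ('a' and 'b') are
    // inferred from the python variables present after execution.
    PythonTransform t = PythonTransform.builder()
            .code("a += 2; b = 'hello world'")
            .returnAllInputs(true)
            .build();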
@@ -59,7 +59,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -28,15 +28,21 @@

     <dependencies>
         <dependency>
-            <groupId>com.googlecode.json-simple</groupId>
-            <artifactId>json-simple</artifactId>
-            <version>1.1</version>
+            <groupId>org.json</groupId>
+            <artifactId>json</artifactId>
+            <version>20190722</version>
         </dependency>
         <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>cpython-platform</artifactId>
            <version>${cpython-platform.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.bytedeco</groupId>
+            <artifactId>numpy-platform</artifactId>
+            <version>${numpy.javacpp.version}</version>
+        </dependency>

         <dependency>
             <groupId>com.google.code.findbugs</groupId>
             <artifactId>jsr305</artifactId>
@@ -65,7 +71,7 @@
             <id>test-nd4j-native</id>
         </profile>
         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
         </profile>
     </profiles>
 </project>
@@ -16,10 +16,13 @@

 package org.datavec.python;

+import lombok.Builder;
 import lombok.Getter;
+import lombok.NoArgsConstructor;
 import org.bytedeco.javacpp.Pointer;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.nativeblas.NativeOps;
 import org.nd4j.nativeblas.NativeOpsHolder;
@@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType;
  * @author Fariz Rahman
  */
 @Getter
+@NoArgsConstructor
 public class NumpyArray {

-    private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
+    private static NativeOps nativeOps;
     private long address;
     private long[] shape;
     private long[] strides;
-    private DataType dtype = DataType.FLOAT;
+    private DataType dtype;
     private INDArray nd4jArray;
+
+    static {
+        //initialize
+        Nd4j.scalar(1.0);
+        nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
+    }

-    public NumpyArray(long address, long[] shape, long strides[], boolean copy){
+    @Builder
+    public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
         this.address = address;
         this.shape = shape;
         this.strides = strides;
+        this.dtype = dtype;
         setND4JArray();
         if (copy){
             nd4jArray = nd4jArray.dup();
@@ -57,8 +68,9 @@ public class NumpyArray {
     public NumpyArray copy(){
         return new NumpyArray(nd4jArray.dup());
     }

     public NumpyArray(long address, long[] shape, long strides[]){
-        this(address, shape, strides, false);
+        this(address, shape, strides, false,DataType.FLOAT);
     }

     public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
@@ -77,9 +89,9 @@ public class NumpyArray {
         }
     }

-    private void setND4JArray(){
+    private void setND4JArray() {
         long size = 1;
-        for(long d: shape){
+        for(long d: shape) {
             size *= d;
         }
         Pointer ptr = nativeOps.pointerForAddress(address);
@@ -88,10 +100,11 @@ public class NumpyArray {
         DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
         int elemSize = buff.getElementSize();
         long[] nd4jStrides = new long[strides.length];
-        for (int i=0; i<strides.length; i++){
+        for (int i = 0; i < strides.length; i++) {
             nd4jStrides[i] = strides[i] / elemSize;
         }
-        this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, 'c', dtype);
+        this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
+
     }
@@ -109,4 +122,4 @@ public class NumpyArray {
         this.nd4jArray = nd4jArray;
     }

 }
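Note: a detail worth calling out in setND4JArray() above is the stride conversion: numpy reports strides in bytes, while ND4J expects strides in elements, hence the division by the buffer's element size. A standalone sketch of the same conversion (hypothetical helper, not part of this commit):

    // Convert numpy byte strides to ND4J element strides.
    // e.g. a float64 array with byte strides {24, 8} has element strides {3, 1}.
    static long[] byteToElementStrides(long[] byteStrides, int elemSize) {
        long[] out = new long[byteStrides.length];
        for (int i = 0; i < byteStrides.length; i++) {
            out[i] = byteStrides[i] / elemSize;
        }
        return out;
    }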
@@ -23,6 +23,8 @@ import org.datavec.api.writable.*;

 import java.util.List;

+import static org.datavec.python.PythonUtils.schemaToPythonVariables;
+
 /**
  * Lets a condition be defined as a python method f that takes no arguments
  * and returns a boolean indicating whether or not to filter a row.
@@ -38,81 +40,28 @@ public class PythonCondition implements Condition {
     private String code;


-    public PythonCondition(String pythonCode){
+    public PythonCondition(String pythonCode) {
+        org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
+        org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
         code = pythonCode;
     }

-    private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
-        PythonVariables pyVars = new PythonVariables();
-        int numCols = schema.numColumns();
-        for (int i=0; i<numCols; i++){
-            String colName = schema.getName(i);
-            ColumnType colType = schema.getType(i);
-            switch (colType){
-                case Long:
-                case Integer:
-                    pyVars.addInt(colName);
-                    break;
-                case Double:
-                case Float:
-                    pyVars.addFloat(colName);
-                    break;
-                case String:
-                    pyVars.addStr(colName);
-                    break;
-                case NDArray:
-                    pyVars.addNDArray(colName);
-                    break;
-                default:
-                    throw new Exception("Unsupported python input type: " + colType.toString());
-            }
-        }
-        return pyVars;
-    }
-    private PythonVariables getPyInputsFromWritables(List<Writable> writables){
-
-        PythonVariables ret = new PythonVariables();
-
-        for (String name: pyInputs.getVariables()){
-            int colIdx = inputSchema.getIndexOfColumn(name);
-            Writable w = writables.get(colIdx);
-            PythonVariables.Type pyType = pyInputs.getType(name);
-            switch (pyType){
-                case INT:
-                    if (w instanceof LongWritable){
-                        ret.addInt(name, ((LongWritable)w).get());
-                    }
-                    else{
-                        ret.addInt(name, ((IntWritable)w).get());
-                    }
-
-                    break;
-                case FLOAT:
-                    ret.addFloat(name, ((DoubleWritable)w).get());
-                    break;
-                case STR:
-                    ret.addStr(name, ((Text)w).toString());
-                    break;
-                case NDARRAY:
-                    ret.addNDArray(name,((NDArrayWritable)w).get());
-                    break;
-            }
-
-        }
-        return ret;
-    }
     @Override
-    public void setInputSchema(Schema inputSchema){
+    public void setInputSchema(Schema inputSchema) {
         this.inputSchema = inputSchema;
         try{
             pyInputs = schemaToPythonVariables(inputSchema);
             PythonVariables pyOuts = new PythonVariables();
             pyOuts.addInt("out");
-            pythonTransform = new PythonTransform(
-                    code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered
-                    pyInputs,
-                    pyOuts
-            );
+            pythonTransform = PythonTransform.builder()
+                    .code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
+                    .inputs(pyInputs)
+                    .outputs(pyOuts)
+                    .build();

         }
         catch (Exception e){
             throw new RuntimeException(e);
@@ -127,41 +76,47 @@ public class PythonCondition implements Condition {
         return inputSchema;
     }

-    public String[] outputColumnNames(){
+    @Override
+    public String[] outputColumnNames() {
         String[] columnNames = new String[inputSchema.numColumns()];
         inputSchema.getColumnNames().toArray(columnNames);
         return columnNames;
     }

+    @Override
     public String outputColumnName(){
         return outputColumnNames()[0];
     }

+    @Override
     public String[] columnNames(){
         return outputColumnNames();
     }

+    @Override
     public String columnName(){
         return outputColumnName();
     }

+    @Override
     public Schema transform(Schema inputSchema){
         return inputSchema;
     }

-    public boolean condition(List<Writable> list){
+    @Override
+    public boolean condition(List<Writable> list) {
         PythonVariables inputs = getPyInputsFromWritables(list);
         try{
             PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
             boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
             return ret;
         }
-        catch (Exception e){
+        catch (Exception e) {
             throw new RuntimeException(e);
         }

     }

+    @Override
     public boolean condition(Object input){
         return condition(input);
     }
@@ -177,5 +132,37 @@ public class PythonCondition implements Condition {
         throw new UnsupportedOperationException("not supported");
     }

-}
+    private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
+        PythonVariables ret = new PythonVariables();
+
+        for (int i = 0; i < inputSchema.numColumns(); i++){
+            String name = inputSchema.getName(i);
+            Writable w = writables.get(i);
+            PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
+            switch (pyType){
+                case INT:
+                    if (w instanceof LongWritable) {
+                        ret.addInt(name, ((LongWritable)w).get());
+                    }
+                    else {
+                        ret.addInt(name, ((IntWritable)w).get());
+                    }
+
+                    break;
+                case FLOAT:
+                    ret.addFloat(name, ((DoubleWritable)w).get());
+                    break;
+                case STR:
+                    ret.addStr(name, w.toString());
+                    break;
+                case NDARRAY:
+                    ret.addNDArray(name,((NDArrayWritable)w).get());
+                    break;
+            }
+        }
+
+        return ret;
+    }
+
+
+}
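Note: PythonCondition now validates its code up front and marks the full Condition interface with @Override. Its intended use, as exercised by testPythonFilter earlier in this diff (a sketch; the lambda must be named f and return a boolean):

    // A python lambda decides row-by-row whether a record should be filtered.
    Condition condition = new PythonCondition("f = lambda: column < 0");
    condition.setInputSchema(schema);
    Filter filter = new ConditionFilter(condition);
    boolean dropped = filter.removeExample(
            Collections.singletonList(new IntWritable(-1)));   // true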
(File diff suppressed because it is too large)
@@ -16,16 +16,29 @@

 package org.datavec.python;

+import lombok.Builder;
 import lombok.Data;
 import lombok.NoArgsConstructor;
+import org.apache.commons.io.IOUtils;
 import org.datavec.api.transform.ColumnType;
 import org.datavec.api.transform.Transform;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.writable.*;
+import org.nd4j.base.Preconditions;
+import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
+import org.nd4j.linalg.io.ClassPathResource;
+import org.nd4j.shade.jackson.core.JsonProcessingException;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.UUID;
+
+import static org.datavec.python.PythonUtils.schemaToPythonVariables;

 /**
  * Row-wise Transform that applies arbitrary python code on each row
  *
@@ -34,31 +47,87 @@ import java.util.UUID;

 @NoArgsConstructor
 @Data
-public class PythonTransform implements Transform{
+public class PythonTransform implements Transform {

     private String code;
-    private PythonVariables pyInputs;
-    private PythonVariables pyOutputs;
-    private String name;
+    private PythonVariables inputs;
+    private PythonVariables outputs;
+    private String name = UUID.randomUUID().toString();
     private Schema inputSchema;
     private Schema outputSchema;
+    private String outputDict;
+    private boolean returnAllVariables;
+    private boolean setupAndRun = false;

-    public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{
+    @Builder
+    public PythonTransform(String code,
+                           PythonVariables inputs,
+                           PythonVariables outputs,
+                           String name,
+                           Schema inputSchema,
+                           Schema outputSchema,
+                           String outputDict,
+                           boolean returnAllInputs,
+                           boolean setupAndRun) {
+        Preconditions.checkNotNull(code,"No code found to run!");
         this.code = code;
-        this.pyInputs = pyInputs;
-        this.pyOutputs = pyOutputs;
-        this.name = UUID.randomUUID().toString();
+        this.returnAllVariables = returnAllInputs;
+        this.setupAndRun = setupAndRun;
+        if(inputs != null)
+            this.inputs = inputs;
+        if(outputs != null)
+            this.outputs = outputs;
+
+        if(name != null)
+            this.name = name;
+        if (outputDict != null) {
+            this.outputDict = outputDict;
+            this.outputs = new PythonVariables();
+            this.outputs.addDict(outputDict);
+
+            String helpers;
+            try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
+                helpers = IOUtils.toString(is, Charset.defaultCharset());
+
+            }catch (IOException e){
+                throw new RuntimeException("Error reading python code");
+            }
+            this.code += "\n\n" + helpers;
+            this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
+        }
+
+        try {
+            if(inputSchema != null) {
+                this.inputSchema = inputSchema;
+                if(inputs == null || inputs.isEmpty()) {
+                    this.inputs = schemaToPythonVariables(inputSchema);
+                }
+            }
+
+            if(outputSchema != null) {
+                this.outputSchema = outputSchema;
+                if(outputs == null || outputs.isEmpty()) {
+                    this.outputs = schemaToPythonVariables(outputSchema);
+                }
+            }
+        }catch(Exception e) {
+            throw new IllegalStateException(e);
+        }
+
     }


     @Override
-    public void setInputSchema(Schema inputSchema){
+    public void setInputSchema(Schema inputSchema) {
+        Preconditions.checkNotNull(inputSchema,"No input schema found!");
         this.inputSchema = inputSchema;
         try{
-            pyInputs = schemaToPythonVariables(inputSchema);
+            inputs = schemaToPythonVariables(inputSchema);
         }catch (Exception e){
             throw new RuntimeException(e);
         }
-        if (outputSchema == null){
+        if (outputSchema == null && outputDict == null){
             outputSchema = inputSchema;
         }
@@ -88,12 +157,42 @@ public class PythonTransform implements Transform {
         throw new UnsupportedOperationException("Not yet implemented");
     }



     @Override
-    public List<Writable> map(List<Writable> writables){
+    public List<Writable> map(List<Writable> writables) {
         PythonVariables pyInputs = getPyInputsFromWritables(writables);
+        Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");


         try{
-            PythonExecutioner.exec(code, pyInputs, pyOutputs);
-            return getWritablesFromPyOutputs(pyOutputs);
+            if (returnAllVariables) {
+                if (setupAndRun){
+                    return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
+                }
+                return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
+            }
+
+            if (outputDict != null) {
+                if (setupAndRun) {
+                    PythonExecutioner.execWithSetupAndRun(this, pyInputs);
+                }else{
+                    PythonExecutioner.exec(this, pyInputs);
+                }
+                PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
+                return getWritablesFromPyOutputs(out);
+            }
+            else {
+                if (setupAndRun) {
+                    PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
+                }else{
+                    PythonExecutioner.exec(code, pyInputs, outputs);
+                }
+
+                return getWritablesFromPyOutputs(outputs);
+            }
+
         }
         catch (Exception e){
             throw new RuntimeException(e);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] outputColumnNames(){
|
public String[] outputColumnNames(){
|
||||||
return pyOutputs.getVariables();
|
return outputs.getVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -111,7 +210,7 @@ public class PythonTransform implements Transform{
|
||||||
}
|
}
|
||||||
@Override
|
@Override
|
||||||
public String[] columnNames(){
|
public String[] columnNames(){
|
||||||
return pyOutputs.getVariables();
|
return outputs.getVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -124,14 +223,13 @@ public class PythonTransform implements Transform{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables){
|
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
||||||
|
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
|
|
||||||
for (String name: pyInputs.getVariables()){
|
for (String name: inputs.getVariables()) {
|
||||||
int colIdx = inputSchema.getIndexOfColumn(name);
|
int colIdx = inputSchema.getIndexOfColumn(name);
|
||||||
Writable w = writables.get(colIdx);
|
Writable w = writables.get(colIdx);
|
||||||
PythonVariables.Type pyType = pyInputs.getType(name);
|
PythonVariables.Type pyType = inputs.getType(name);
|
||||||
switch (pyType){
|
switch (pyType){
|
||||||
case INT:
|
case INT:
|
||||||
if (w instanceof LongWritable){
|
if (w instanceof LongWritable){
|
||||||
|
@@ -143,7 +241,7 @@ public class PythonTransform implements Transform {

                     break;
                 case FLOAT:
-                    if (w instanceof DoubleWritable){
+                    if (w instanceof DoubleWritable) {
                         ret.addFloat(name, ((DoubleWritable)w).get());
                     }
                     else{
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case STR:
|
case STR:
|
||||||
ret.addStr(name, ((Text)w).toString());
|
ret.addStr(name, w.toString());
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
ret.addNDArray(name,((NDArrayWritable)w).get());
|
||||||
break;
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Unsupported input type:" + pyType);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts){
|
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
|
||||||
List<Writable> out = new ArrayList<>();
|
List<Writable> out = new ArrayList<>();
|
||||||
for (int i=0; i<outputSchema.numColumns(); i++){
|
String[] varNames;
|
||||||
String name = outputSchema.getName(i);
|
varNames = pyOuts.getVariables();
|
||||||
PythonVariables.Type pyType = pyOutputs.getType(name);
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
|
for (int i = 0; i < varNames.length; i++) {
|
||||||
|
String name = varNames[i];
|
||||||
|
PythonVariables.Type pyType = pyOuts.getType(name);
|
||||||
switch (pyType){
|
switch (pyType){
|
||||||
case INT:
|
case INT:
|
||||||
out.add((Writable) new LongWritable(pyOuts.getIntValue(name)));
|
schemaBuilder.addColumnLong(name);
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
out.add((Writable) new DoubleWritable(pyOuts.getFloatValue(name)));
|
schemaBuilder.addColumnDouble(name);
|
||||||
break;
|
break;
|
||||||
case STR:
|
case STR:
|
||||||
out.add((Writable) new Text(pyOuts.getStrValue(name)));
|
case DICT:
|
||||||
|
case LIST:
|
||||||
|
schemaBuilder.addColumnString(name);
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
out.add((Writable) new NDArrayWritable(pyOuts.getNDArrayValue(name).getNd4jArray()));
|
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
||||||
|
schemaBuilder.addColumnNDArray(name, arr.getShape());
|
||||||
break;
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalStateException("Unable to support type " + pyType.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.outputSchema = schemaBuilder.build();
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = 0; i < varNames.length; i++) {
|
||||||
|
String name = varNames[i];
|
||||||
|
PythonVariables.Type pyType = pyOuts.getType(name);
|
||||||
|
|
||||||
|
switch (pyType){
|
||||||
|
case INT:
|
||||||
|
out.add(new LongWritable(pyOuts.getIntValue(name)));
|
||||||
|
break;
|
||||||
|
case FLOAT:
|
||||||
|
out.add(new DoubleWritable(pyOuts.getFloatValue(name)));
|
||||||
|
break;
|
||||||
|
case STR:
|
||||||
|
out.add(new Text(pyOuts.getStrValue(name)));
|
||||||
|
break;
|
||||||
|
case NDARRAY:
|
||||||
|
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
||||||
|
out.add(new NDArrayWritable(arr.getNd4jArray()));
|
||||||
|
break;
|
||||||
|
case DICT:
|
||||||
|
Map<?, ?> dictValue = pyOuts.getDictValue(name);
|
||||||
|
Map noNullValues = new java.util.HashMap<>();
|
||||||
|
for(Map.Entry entry : dictValue.entrySet()) {
|
||||||
|
if(entry.getValue() != org.json.JSONObject.NULL) {
|
||||||
|
noNullValues.put(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case LIST:
|
||||||
|
Object[] listValue = pyOuts.getListValue(name);
|
||||||
|
try {
|
||||||
|
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalStateException("Unable to support type " + pyType.name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public PythonTransform(String code) throws Exception{
|
|
||||||
this.code = code;
|
|
||||||
this.name = UUID.randomUUID().toString();
|
|
||||||
}
|
|
||||||
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
|
|
||||||
PythonVariables pyVars = new PythonVariables();
|
|
||||||
int numCols = schema.numColumns();
|
|
||||||
for (int i=0; i<numCols; i++){
|
|
||||||
String colName = schema.getName(i);
|
|
||||||
ColumnType colType = schema.getType(i);
|
|
||||||
switch (colType){
|
|
||||||
case Long:
|
|
||||||
case Integer:
|
|
||||||
pyVars.addInt(colName);
|
|
||||||
break;
|
|
||||||
case Double:
|
|
||||||
case Float:
|
|
||||||
pyVars.addFloat(colName);
|
|
||||||
break;
|
|
||||||
case String:
|
|
||||||
pyVars.addStr(colName);
|
|
||||||
break;
|
|
||||||
case NDArray:
|
|
||||||
pyVars.addNDArray(colName);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new Exception("Unsupported python input type: " + colType.toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pyVars;
|
|
||||||
}
|
|
||||||
|
|
||||||
public PythonTransform(String code, Schema outputSchema) throws Exception{
|
|
||||||
this.code = code;
|
|
||||||
this.name = UUID.randomUUID().toString();
|
|
||||||
this.outputSchema = outputSchema;
|
|
||||||
this.pyOutputs = schemaToPythonVariables(outputSchema);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
public String getName() {
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getCode(){
|
|
||||||
return code;
|
|
||||||
}
|
|
||||||
|
|
||||||
public PythonVariables getInputs() {
|
|
||||||
return pyInputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
public PythonVariables getOutputs() {
|
|
||||||
return pyOutputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
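Note: the DICT branch above strips org.json's null sentinel before serializing, since Python None values come back as JSONObject.NULL rather than Java null. A self-contained sketch of that loop (hypothetical standalone form, not part of this commit):

    import java.util.HashMap;
    import java.util.Map;
    import org.json.JSONObject;

    // Drop entries whose value is the org.json null sentinel, as the
    // DICT case does before writing the map out as a json string.
    static Map<Object, Object> stripJsonNulls(Map<?, ?> in) {
        Map<Object, Object> out = new HashMap<>();
        for (Map.Entry<?, ?> e : in.entrySet()) {
            if (e.getValue() != JSONObject.NULL) {
                out.put(e.getKey(), e.getValue());
            }
        }
        return out;
    }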
@ -0,0 +1,306 @@
package org.datavec.python;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.BooleanMetaData;
import org.datavec.api.transform.schema.Schema;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Utilities for executing python transforms.
 *
 * @author Adam Gibson
 */
public class PythonUtils {

    /**
     * Create a {@link Schema} from {@link PythonVariables}.
     * Types are mapped to types of the same name.
     * @param input the input {@link PythonVariables}
     * @return the output {@link Schema}
     */
    public static Schema fromPythonVariables(PythonVariables input) {
        Schema.Builder schemaBuilder = new Schema.Builder();
        Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none.");
        for (Map.Entry<String, PythonVariables.Type> entry : input.getVars().entrySet()) {
            switch (entry.getValue()) {
                case INT:
                    schemaBuilder.addColumnInteger(entry.getKey());
                    break;
                case STR:
                    schemaBuilder.addColumnString(entry.getKey());
                    break;
                case FLOAT:
                    schemaBuilder.addColumnFloat(entry.getKey());
                    break;
                case NDARRAY:
                    schemaBuilder.addColumnNDArray(entry.getKey(), null);
                    break;
                case BOOL:
                    schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
            }
        }
        return schemaBuilder.build();
    }

    /**
     * Create {@link PythonVariables} from an input {@link Schema}.
     * Types are mapped to types of the same name.
     * @param input the input schema
     * @return the output python variables.
     */
    public static PythonVariables fromSchema(Schema input) {
        PythonVariables ret = new PythonVariables();
        for (int i = 0; i < input.numColumns(); i++) {
            String currColumnName = input.getName(i);
            ColumnType columnType = input.getType(i);
            switch (columnType) {
                case NDArray:
                    ret.add(currColumnName, PythonVariables.Type.NDARRAY);
                    break;
                case Boolean:
                    ret.add(currColumnName, PythonVariables.Type.BOOL);
                    break;
                case Categorical:
                case String:
                    ret.add(currColumnName, PythonVariables.Type.STR);
                    break;
                case Double:
                case Float:
                    ret.add(currColumnName, PythonVariables.Type.FLOAT);
                    break;
                case Integer:
                case Long:
                    ret.add(currColumnName, PythonVariables.Type.INT);
                    break;
                case Bytes:
                    break;
                case Time:
                    throw new UnsupportedOperationException("Unable to process dates with python yet.");
            }
        }
        return ret;
    }

    /**
     * Convert a {@link Schema} to {@link PythonVariables}.
     * @param schema the input schema
     * @return the output {@link PythonVariables} where each name in the map is
     * associated with a column name in the schema. A proper type is also chosen
     * based on the schema.
     * @throws Exception if a column type is not supported
     */
    public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception {
        PythonVariables pyVars = new PythonVariables();
        int numCols = schema.numColumns();
        for (int i = 0; i < numCols; i++) {
            String colName = schema.getName(i);
            ColumnType colType = schema.getType(i);
            switch (colType) {
                case Long:
                case Integer:
                    pyVars.addInt(colName);
                    break;
                case Double:
                case Float:
                    pyVars.addFloat(colName);
                    break;
                case String:
                    pyVars.addStr(colName);
                    break;
                case NDArray:
                    pyVars.addNDArray(colName);
                    break;
                default:
                    throw new Exception("Unsupported python input type: " + colType.toString());
            }
        }
        return pyVars;
    }

    public static NumpyArray mapToNumpyArray(Map map) {
        String dtypeName = (String) map.get("dtype");
        DataType dtype;
        if (dtypeName.equals("float64")) {
            dtype = DataType.DOUBLE;
        } else if (dtypeName.equals("float32")) {
            dtype = DataType.FLOAT;
        } else if (dtypeName.equals("int16")) {
            dtype = DataType.SHORT;
        } else if (dtypeName.equals("int32")) {
            dtype = DataType.INT;
        } else if (dtypeName.equals("int64")) {
            dtype = DataType.LONG;
        } else {
            throw new RuntimeException("Unsupported array type " + dtypeName + ".");
        }
        List shapeList = (List) map.get("shape");
        long[] shape = new long[shapeList.size()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = (Long) shapeList.get(i);
        }

        // The serialized metadata stores the strides under the "strides" key
        List strideList = (List) map.get("strides");
        long[] stride = new long[strideList.size()];
        for (int i = 0; i < stride.length; i++) {
            stride[i] = (Long) strideList.get(i);
        }
        long address = (Long) map.get("address");
        return new NumpyArray(address, shape, stride, true, dtype);
    }

    public static PythonVariables expandInnerDict(PythonVariables pyvars, String key) {
        Map dict = pyvars.getDictValue(key);
        String[] keys = (String[]) dict.keySet().toArray(new String[dict.keySet().size()]);
        PythonVariables pyvars2 = new PythonVariables();
        for (String subkey : keys) {
            Object value = dict.get(subkey);
            if (value instanceof Map) {
                Map map = (Map) value;
                if (map.containsKey("_is_numpy_array")) {
                    pyvars2.addNDArray(subkey, mapToNumpyArray(map));
                } else {
                    pyvars2.addDict(subkey, (Map) value);
                }
            } else if (value instanceof List) {
                pyvars2.addList(subkey, ((List) value).toArray());
            } else if (value instanceof String) {
                pyvars2.addStr(subkey, (String) value);
            } else if (value instanceof Integer || value instanceof Long) {
                Number number = (Number) value;
                pyvars2.addInt(subkey, number.intValue());
            } else if (value instanceof Float || value instanceof Double) {
                Number number = (Number) value;
                pyvars2.addFloat(subkey, number.doubleValue());
            } else if (value instanceof NumpyArray) {
                pyvars2.addNDArray(subkey, (NumpyArray) value);
            } else if (value == null) {
                pyvars2.addStr(subkey, "None"); // FixMe
            } else {
                throw new RuntimeException("Unsupported type: " + value);
            }
        }
        return pyvars2;
    }

    public static long[] jsonArrayToLongArray(JSONArray jsonArray) {
        long[] longs = new long[jsonArray.length()];
        for (int i = 0; i < longs.length; i++) {
            longs[i] = jsonArray.getLong(i);
        }
        return longs;
    }

    public static Map<String, Object> toMap(JSONObject jsonobj) {
        Map<String, Object> map = new HashMap<>();
        String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
        for (String key : keys) {
            Object value = jsonobj.get(key);
            if (value instanceof JSONArray) {
                value = toList((JSONArray) value);
            } else if (value instanceof JSONObject) {
                JSONObject jsonobj2 = (JSONObject) value;
                if (jsonobj2.has("_is_numpy_array")) {
                    value = jsonToNumpyArray(jsonobj2);
                } else {
                    value = toMap(jsonobj2);
                }
            }
            map.put(key, value);
        }
        return map;
    }

    public static List<Object> toList(JSONArray array) {
        List<Object> list = new ArrayList<>();
        for (int i = 0; i < array.length(); i++) {
            Object value = array.get(i);
            if (value instanceof JSONArray) {
                value = toList((JSONArray) value);
            } else if (value instanceof JSONObject) {
                JSONObject jsonobj2 = (JSONObject) value;
                if (jsonobj2.has("_is_numpy_array")) {
                    value = jsonToNumpyArray(jsonobj2);
                } else {
                    value = toMap(jsonobj2);
                }
            }
            list.add(value);
        }
        return list;
    }

    private static NumpyArray jsonToNumpyArray(JSONObject map) {
        String dtypeName = (String) map.get("dtype");
        DataType dtype;
        if (dtypeName.equals("float64")) {
            dtype = DataType.DOUBLE;
        } else if (dtypeName.equals("float32")) {
            dtype = DataType.FLOAT;
        } else if (dtypeName.equals("int16")) {
            dtype = DataType.SHORT;
        } else if (dtypeName.equals("int32")) {
            dtype = DataType.INT;
        } else if (dtypeName.equals("int64")) {
            dtype = DataType.LONG;
        } else {
            throw new RuntimeException("Unsupported array type " + dtypeName + ".");
        }
        // org.json stores nested arrays as JSONArray, so convert via the helper;
        // strides come from the "strides" key of the serialized metadata
        long[] shape = jsonArrayToLongArray((JSONArray) map.get("shape"));
        long[] stride = jsonArrayToLongArray((JSONArray) map.get("strides"));
        long address = (Long) map.get("address");
        return new NumpyArray(address, shape, stride, true, dtype);
    }

}

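A short sketch of the schema round trip these helpers provide. The column names are illustrative, and the calls must run where a checked Exception can propagate, because schemaToPythonVariables declares one.

    Schema schema = new Schema.Builder()
            .addColumnInteger("id")
            .addColumnFloat("score")
            .build();
    // Schema -> PythonVariables: each column becomes a python variable of the matching type
    PythonVariables vars = PythonUtils.schemaToPythonVariables(schema);
    // PythonVariables -> Schema: names and types map straight back to columns
    Schema roundTripped = PythonUtils.fromPythonVariables(vars);
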
@ -17,8 +17,8 @@
package org.datavec.python;

import lombok.Data;
import org.json.JSONObject;
import org.json.JSONArray;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

@ -31,8 +31,8 @@ import java.util.*;
 * @author Fariz Rahman
 */

@lombok.Data
public class PythonVariables implements java.io.Serializable {

    public enum Type{
        BOOL,
@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{
        FLOAT,
        NDARRAY,
        LIST,
        FILE,
        DICT
    }

    private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
    private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>();
    private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
    private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();

    /**
     * Returns a copy of the variable
     * schema in this array without the values
     * @return an empty variables clone with no values
     */
    public PythonVariables copySchema(){
        PythonVariables ret = new PythonVariables();
        for (String varName: getVariables()){
@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{
        }
        return ret;
    }

    /**
     * Creates a new PythonVariables instance, registering the
     * backing map for each supported {@link Type}.
     */
    public PythonVariables() {
        maps.put(PythonVariables.Type.BOOL, boolVariables);
        maps.put(PythonVariables.Type.STR, strVariables);
        maps.put(PythonVariables.Type.INT, intVariables);
        maps.put(PythonVariables.Type.FLOAT, floatVariables);
        maps.put(PythonVariables.Type.NDARRAY, ndVars);
        maps.put(PythonVariables.Type.LIST, listVariables);
        maps.put(PythonVariables.Type.FILE, fileVariables);
        maps.put(PythonVariables.Type.DICT, dictVariables);
    }

    /**
     * @return true if there are no variables.
     */
    public boolean isEmpty() {
        return getVariables().length < 1;
    }

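A brief sketch of how the per-type maps registered in the constructor are exercised through the add and get methods defined below; the variable names are illustrative.

    PythonVariables pyInputs = new PythonVariables();
    pyInputs.addInt("x", 10);          // ints are widened and stored as Long
    pyInputs.addFloat("y", 2.5);       // floats and doubles are stored as Double
    pyInputs.addStr("msg", "hello");
    long x = pyInputs.getIntValue("x");
    Object msg = pyInputs.getValue("msg"); // generic lookup routed through the type's map
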
@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{
                break;
            case FILE:
                addFile(name);
                break;
            case DICT:
                addDict(name);
        }
    }

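The two-step registration that add(name, type) enables is typically paired with setValue, as sketched here with an illustrative variable name.

    PythonVariables out = new PythonVariables();
    out.add("z", PythonVariables.Type.INT); // registers the name with a null placeholder value
    out.setValue("z", 30L);                 // filled in later, e.g. once execution finishes
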
@ -113,252 +137,463 @@ public class PythonVariables implements Serializable{
     * @param name name of the variable
     * @param type type of the variable
     * @param value value of the variable (must be instance of expected type)
     */
    public void add(String name, Type type, Object value) {
        add(name, type);
        setValue(name, value);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addDict(String name) {
        vars.put(name, PythonVariables.Type.DICT);
        dictVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addBool(String name){
        vars.put(name, PythonVariables.Type.BOOL);
        boolVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addStr(String name){
        vars.put(name, PythonVariables.Type.STR);
        strVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addInt(String name){
        vars.put(name, PythonVariables.Type.INT);
        intVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addFloat(String name){
        vars.put(name, PythonVariables.Type.FLOAT);
        floatVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addNDArray(String name){
        vars.put(name, PythonVariables.Type.NDARRAY);
        ndVars.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addList(String name){
        vars.put(name, PythonVariables.Type.LIST);
        listVariables.put(name, null);
    }

    /**
     * Add a null variable to
     * the set of variables
     * to describe the type but no value
     * @param name the field to add
     */
    public void addFile(String name){
        vars.put(name, PythonVariables.Type.FILE);
        fileVariables.put(name, null);
    }

    /**
     * Add a boolean variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addBool(String name, boolean value) {
        vars.put(name, PythonVariables.Type.BOOL);
        boolVariables.put(name, value);
    }

    /**
     * Add a string variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addStr(String name, String value) {
        vars.put(name, PythonVariables.Type.STR);
        strVariables.put(name, value);
    }

    /**
     * Add an int variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addInt(String name, int value) {
        vars.put(name, PythonVariables.Type.INT);
        intVariables.put(name, (long)value);
    }

    /**
     * Add a long variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addInt(String name, long value) {
        vars.put(name, PythonVariables.Type.INT);
        intVariables.put(name, value);
    }

    /**
     * Add a double variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addFloat(String name, double value) {
        vars.put(name, PythonVariables.Type.FLOAT);
        floatVariables.put(name, value);
    }

    /**
     * Add a float variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addFloat(String name, float value) {
        vars.put(name, PythonVariables.Type.FLOAT);
        floatVariables.put(name, (double)value);
    }

    /**
     * Add a numpy array variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addNDArray(String name, NumpyArray value) {
        vars.put(name, PythonVariables.Type.NDARRAY);
        ndVars.put(name, value);
    }

    /**
     * Add an INDArray variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
        vars.put(name, PythonVariables.Type.NDARRAY);
        ndVars.put(name, new NumpyArray(value));
    }

    /**
     * Add a list variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addList(String name, Object[] value) {
        vars.put(name, PythonVariables.Type.LIST);
        listVariables.put(name, value);
    }

    /**
     * Add a file path variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addFile(String name, String value) {
        vars.put(name, PythonVariables.Type.FILE);
        fileVariables.put(name, value);
    }

    /**
     * Add a dict variable to
     * the set of variables
     * @param name the field to add
     * @param value the value to add
     */
    public void addDict(String name, java.util.Map value) {
        vars.put(name, PythonVariables.Type.DICT);
        dictVariables.put(name, value);
    }

    /**
     *
     * @param name name of the variable
     * @param value new value for the variable
     */
    public void setValue(String name, Object value) {
        Type type = vars.get(name);
        if (type == PythonVariables.Type.BOOL){
            boolVariables.put(name, (Boolean)value);
        }
        else if (type == PythonVariables.Type.INT){
            Number number = (Number) value;
            intVariables.put(name, number.longValue());
        }
        else if (type == PythonVariables.Type.FLOAT){
            Number number = (Number) value;
            floatVariables.put(name, number.doubleValue());
        }
        else if (type == PythonVariables.Type.NDARRAY){
            if (value instanceof NumpyArray){
                ndVars.put(name, (NumpyArray)value);
            }
            else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
                ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
            }
            else{
                throw new RuntimeException("Unsupported type: " + value.getClass().toString());
            }
        }
        else if (type == PythonVariables.Type.LIST) {
            if (value instanceof java.util.List) {
                value = ((java.util.List) value).toArray();
                listVariables.put(name, (Object[]) value);
            }
            else if(value instanceof org.json.JSONArray) {
                org.json.JSONArray jsonArray = (org.json.JSONArray) value;
                Object[] copyArr = new Object[jsonArray.length()];
                for(int i = 0; i < copyArr.length; i++) {
                    copyArr[i] = jsonArray.get(i);
                }
                listVariables.put(name, copyArr);
            }
            else {
                listVariables.put(name, (Object[]) value);
            }
        }
        else if(type == PythonVariables.Type.DICT) {
            dictVariables.put(name, (java.util.Map<?,?>) value);
        }
        else if (type == PythonVariables.Type.FILE){
            fileVariables.put(name, (String)value);
        }
        else{
            strVariables.put(name, (String)value);
        }
    }

    /**
     * Do a general object lookup.
     * The lookup is resolved via the {@link Type}
     * registered for the variable.
     * @param name the name of the variable to get
     * @return the value for the variable with the given name
     */
    public Object getValue(String name) {
        Type type = vars.get(name);
        java.util.Map map = maps.get(type);
        return map.get(name);
    }

    /**
     * Returns a boolean variable with the given name.
     * @param name the variable name to get the value for
     * @return the retrieved boolean value
     */
    public boolean getBooleanValue(String name) {
        return boolVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the dictionary value
     */
    public java.util.Map<?,?> getDictValue(String name) {
        return dictVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the string value
     */
    public String getStrValue(String name){
        return strVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the long value
     */
    public Long getIntValue(String name){
        return intVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the float value
     */
    public Double getFloatValue(String name){
        return floatVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the numpy array value
     */
    public NumpyArray getNDArrayValue(String name){
        return ndVars.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the list value as an object array
     */
    public Object[] getListValue(String name){
        return listVariables.get(name);
    }

    /**
     *
     * @param name the variable name
     * @return the value of the given file name
     */
    public String getFileValue(String name){
        return fileVariables.get(name);
    }

    /**
     * Returns the type for the given variable name
     * @param name the name of the variable to get the type for
     * @return the type for the given variable
     */
    public Type getType(String name){
        return vars.get(name);
    }

    /**
     * Get all the variables present as a string array
     * @return the variable names for this variable set
     */
    public String[] getVariables() {
        String[] strArr = new String[vars.size()];
        return vars.keySet().toArray(strArr);
    }

    /**
     * This variables set as its json representation (an array of json objects)
     * @return the json array output
     */
    public org.json.JSONArray toJSON(){
        org.json.JSONArray arr = new org.json.JSONArray();
        for (String varName: getVariables()){
            org.json.JSONObject var = new org.json.JSONObject();
            var.put("name", varName);
            String varType = getType(varName).toString();
            var.put("type", varType);
            arr.put(var);
        }
        return arr;
    }

    /**
     * Create a schema from a map.
     * This is an empty PythonVariables
     * that just contains names and types with no values
     * @param inputTypes the input types to convert
     * @return the schema from the given map
     */
    public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
        PythonVariables ret = new PythonVariables();
        for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) {
            ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
        }
        return ret;
    }

    /**
     * Get the python variable state relative to the
     * input json array
     * @param jsonArray the input json array
     * @return the python variables based on the input json array
     */
    public static PythonVariables fromJSON(org.json.JSONArray jsonArray){
        PythonVariables pyvars = new PythonVariables();
        for (int i = 0; i < jsonArray.length(); i++) {
            org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
            String varName = (String)input.get("name");
            String varType = (String)input.get("type");
            if (varType.equals("BOOL")) {
                pyvars.addBool(varName);
            }
            else if (varType.equals("INT")) {
                pyvars.addInt(varName);
            }
            else if (varType.equals("FLOAT")) {
                pyvars.addFloat(varName);
            }
            else if (varType.equals("STR")) {
                pyvars.addStr(varName);
            }
            else if (varType.equals("LIST")) {
                pyvars.addList(varName);
            }
            else if (varType.equals("FILE")){
                pyvars.addFile(varName);
            }
            else if (varType.equals("NDARRAY")) {
                pyvars.addNDArray(varName);
            }
            else if(varType.equals("DICT")) {
                pyvars.addDict(varName);
            }
        }
        return pyvars;
    }

}

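A sketch of the JSON round trip defined by toJSON and fromJSON above. Only names and types survive the trip; values are intentionally dropped, so the restored object is an empty schema.

    PythonVariables vars = new PythonVariables();
    vars.addInt("x");
    vars.addNDArray("arr");
    org.json.JSONArray json = vars.toJSON();                    // [{"name":"x","type":"INT"}, ...]
    PythonVariables restored = PythonVariables.fromJSON(json);  // same names and types, no values
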
@ -0,0 +1,5 @@
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
import sys
this = sys.modules[__name__]
for n in dir():
    if n[0] != '_': delattr(this, n)

@ -0,0 +1 @@
loc = {}

@ -0,0 +1,20 @@
import inspect


def __is_numpy_array(x):
    return str(type(x)) == "<class 'numpy.ndarray'>"

def maybe_serialize_ndarray_metadata(x):
    return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x


def serialize_ndarray_metadata(x):
    return {"address": x.__array_interface__['data'][0],
            "shape": x.shape,
            "strides": x.strides,
            "dtype": str(x.dtype),
            "_is_numpy_array": True} if __is_numpy_array(x) else x


def is_json_ready(key, value):
    return key != 'f2' and not inspect.ismodule(value) \
           and not hasattr(value, '__call__')

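On the Java side, a metadata dict shaped like the one serialize_ndarray_metadata emits can be rebuilt into a NumpyArray wrapper via PythonUtils.mapToNumpyArray. A hedged sketch follows; the map literal stands in for a dict handed over from the interpreter, and the zero address is a placeholder rather than a dereferenceable pointer.

    java.util.Map<String, Object> meta = new java.util.HashMap<>();
    meta.put("dtype", "float32");
    meta.put("shape", java.util.Arrays.asList(2L, 3L));
    meta.put("strides", java.util.Arrays.asList(12L, 4L)); // bytes between elements, as numpy reports them
    meta.put("address", 0L); // a real buffer address would come from __array_interface__
    meta.put("_is_numpy_array", true);
    NumpyArray arr = PythonUtils.mapToNumpyArray(meta);
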
@ -0,0 +1,202 @@
#patch

"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os

from numpy.core._multiarray_umath import (
    add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec


ENABLE_ARRAY_FUNCTION = bool(
    int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))


ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION  # backward compat


_add_docstring = add_docstring


def add_docstring(*args):
    try:
        _add_docstring(*args)
    except:
        pass


add_docstring(
    implement_array_function,
    """
    Implement a function with checks for __array_function__ overrides.

    All arguments are required, and can only be passed by position.

    Arguments
    ---------
    implementation : function
        Function that implements the operation on NumPy array without
        overrides when called like ``implementation(*args, **kwargs)``.
    public_api : function
        Function exposed by NumPy's public API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __array_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    Result from calling ``implementation()`` or an ``__array_function__``
    method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.
    """)


# exposed for testing purposes; used internally by implement_array_function
add_docstring(
    _get_implementing_args,
    """
    Collect arguments on which to call __array_function__.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Iterable of possibly array-like arguments to check for
        __array_function__ methods.

    Returns
    -------
    Sequence of arguments with __array_function__ methods, in the order in
    which they should be called.
    """)


ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')


def verify_matching_signatures(implementation, dispatcher):
    """Verify that a dispatcher function has the right signature."""
    implementation_spec = ArgSpec(*getargspec(implementation))
    dispatcher_spec = ArgSpec(*getargspec(dispatcher))

    if (implementation_spec.args != dispatcher_spec.args or
            implementation_spec.varargs != dispatcher_spec.varargs or
            implementation_spec.keywords != dispatcher_spec.keywords or
            (bool(implementation_spec.defaults) !=
             bool(dispatcher_spec.defaults)) or
            (implementation_spec.defaults is not None and
             len(implementation_spec.defaults) !=
             len(dispatcher_spec.defaults))):
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    if implementation_spec.defaults is not None:
        if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
            raise RuntimeError('dispatcher functions can only use None for '
                               'default argument values')


def set_module(module):
    """Decorator for overriding __module__ on a function or class.

    Example usage::

        @set_module('numpy')
        def example():
            pass

        assert example.__module__ == 'numpy'
    """
    def decorator(func):
        if module is not None:
            func.__module__ = module
        return func
    return decorator


def array_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False):
    """Decorator for adding dispatch with the __array_function__ protocol.

    See NEP-18 for example usage.

    Parameters
    ----------
    dispatcher : callable
        Function that when called like ``dispatcher(*args, **kwargs)`` with
        arguments from the NumPy function call returns an iterable of
        array-like arguments to check for ``__array_function__``.
    module : str, optional
        __module__ attribute to set on new function, e.g., ``module='numpy'``.
        By default, module is copied from the decorated function.
    verify : bool, optional
        If True, verify that the dispatcher and decorated
        function signatures match exactly: all required and optional arguments
        should appear in order with the same names, but the default values for
        all optional arguments should be ``None``. Only disable verification
        if the dispatcher's signature needs to deviate for some particular
        reason, e.g., because the function has a signature like
        ``func(*args, **kwargs)``.
    docs_from_dispatcher : bool, optional
        If True, copy docs from the dispatcher function onto the dispatched
        function, rather than from the implementation. This is useful for
        functions defined in C, which otherwise don't have docstrings.

    Returns
    -------
    Function suitable for decorating the implementation of a NumPy function.
    """

    if not ENABLE_ARRAY_FUNCTION:
        # __array_function__ requires an explicit opt-in for now
        def decorator(implementation):
            if module is not None:
                implementation.__module__ = module
            if docs_from_dispatcher:
                add_docstring(implementation, dispatcher.__doc__)
            return implementation
        return decorator

    def decorator(implementation):
        if verify:
            verify_matching_signatures(implementation, dispatcher)

        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)

        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            return implement_array_function(
                implementation, public_api, relevant_args, args, kwargs)

        if module is not None:
            public_api.__module__ = module

        # TODO: remove this when we drop Python 2 support (functools.wraps
        # adds __wrapped__ automatically in later versions)
        public_api.__wrapped__ = implementation

        return public_api

    return decorator


def array_function_from_dispatcher(
        implementation, module=None, verify=True, docs_from_dispatcher=True):
    """Like array_function_dispatcher, but with function arguments flipped."""

    def decorator(dispatcher):
        return array_function_dispatch(
            dispatcher, module, verify=verify,
            docs_from_dispatcher=docs_from_dispatcher)(implementation)
    return decorator

@ -0,0 +1,172 @@
#patch 1

"""
========================
Random Number Generation
========================

==================== =========================================================
Utility functions
==============================================================================
random_sample        Uniformly distributed floats over ``[0, 1)``.
random               Alias for `random_sample`.
bytes                Uniformly distributed random bytes.
random_integers      Uniformly distributed integers in a given range.
permutation          Randomly permute a sequence / generate a random sequence.
shuffle              Randomly permute a sequence in place.
seed                 Seed the random number generator.
choice               Random sample from 1-D array.
==================== =========================================================

==================== =========================================================
Compatibility functions
==============================================================================
rand                 Uniformly distributed values.
randn                Normally distributed values.
ranf                 Uniformly distributed floating point numbers.
randint              Uniformly distributed integers in a given range.
==================== =========================================================

==================== =========================================================
Univariate distributions
==============================================================================
beta                 Beta distribution over ``[0, 1]``.
binomial             Binomial distribution.
chisquare            :math:`\\chi^2` distribution.
exponential          Exponential distribution.
f                    F (Fisher-Snedecor) distribution.
gamma                Gamma distribution.
geometric            Geometric distribution.
gumbel               Gumbel distribution.
hypergeometric       Hypergeometric distribution.
laplace              Laplace distribution.
logistic             Logistic distribution.
lognormal            Log-normal distribution.
logseries            Logarithmic series distribution.
negative_binomial    Negative binomial distribution.
noncentral_chisquare Non-central chi-square distribution.
noncentral_f         Non-central F distribution.
normal               Normal / Gaussian distribution.
pareto               Pareto distribution.
poisson              Poisson distribution.
power                Power distribution.
rayleigh             Rayleigh distribution.
triangular           Triangular distribution.
uniform              Uniform distribution.
vonmises             Von Mises circular distribution.
wald                 Wald (inverse Gaussian) distribution.
weibull              Weibull distribution.
zipf                 Zipf's distribution over ranked data.
==================== =========================================================

==================== =========================================================
Multivariate distributions
==============================================================================
dirichlet            Multivariate generalization of Beta distribution.
multinomial          Multivariate generalization of the binomial distribution.
multivariate_normal  Multivariate generalization of the normal distribution.
==================== =========================================================

==================== =========================================================
Standard distributions
==============================================================================
standard_cauchy      Standard Cauchy-Lorentz distribution.
standard_exponential Standard exponential distribution.
standard_gamma       Standard Gamma distribution.
standard_normal      Standard normal distribution.
standard_t           Standard Student's t-distribution.
==================== =========================================================

==================== =========================================================
Internal functions
==============================================================================
get_state            Get tuple representing internal state of generator.
set_state            Set state of generator.
==================== =========================================================

"""
from __future__ import division, absolute_import, print_function

import warnings

__all__ = [
    'beta',
    'binomial',
    'bytes',
    'chisquare',
    'choice',
    'dirichlet',
    'exponential',
    'f',
    'gamma',
    'geometric',
    'get_state',
    'gumbel',
    'hypergeometric',
    'laplace',
    'logistic',
    'lognormal',
    'logseries',
    'multinomial',
    'multivariate_normal',
    'negative_binomial',
    'noncentral_chisquare',
    'noncentral_f',
    'normal',
    'pareto',
    'permutation',
    'poisson',
    'power',
    'rand',
    'randint',
    'randn',
    'random_integers',
    'random_sample',
    'rayleigh',
    'seed',
    'set_state',
    'shuffle',
    'standard_cauchy',
    'standard_exponential',
    'standard_gamma',
    'standard_normal',
    'standard_t',
    'triangular',
    'uniform',
    'vonmises',
    'wald',
    'weibull',
    'zipf'
]

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
    try:
        from .mtrand import *
        # Some aliases:
        ranf = random = sample = random_sample
        __all__.extend(['ranf', 'random', 'sample'])
    except:
        warnings.warn("numpy.random is not available when using multiple interpreters!")


def __RandomState_ctor():
    """Return a RandomState instance.

    This function exists solely to assist (un)pickling.

    Note that the state of the RandomState returned here is irrelevant, as this function's
    entire purpose is to return a newly allocated RandomState whose state pickle can set.
    Consequently the RandomState returned by this function is a freshly allocated copy
    with a seed=0.

    See https://github.com/numpy/numpy/issues/4763 for a detailed discussion

    """
    return RandomState(seed=0)

from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester

@ -0,0 +1,20 @@
import sys
import traceback
import json
import inspect


try:

    pass
    sys.stdout.flush()
    sys.stderr.flush()
except Exception as ex:
    try:
        exc_info = sys.exc_info()
    finally:
        print(ex)
        traceback.print_exception(*exc_info)
        sys.stdout.flush()
        sys.stderr.flush()

@ -0,0 +1,50 @@
def __is_numpy_array(x):
    return str(type(x)) == "<class 'numpy.ndarray'>"

def __maybe_serialize_ndarray_metadata(x):
    return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x


def __serialize_ndarray_metadata(x):
    return {"address": x.__array_interface__['data'][0],
            "shape": x.shape,
            "strides": x.strides,
            "dtype": str(x.dtype),
            "_is_numpy_array": True} if __is_numpy_array(x) else x


def __serialize_list(x):
    import json
    return json.dumps(__recursive_serialize_list(x))


def __serialize_dict(x):
    import json
    return json.dumps(__recursive_serialize_dict(x))

def __recursive_serialize_list(x):
    out = []
    for i in x:
        if __is_numpy_array(i):
            out.append(__serialize_ndarray_metadata(i))
        elif isinstance(i, (list, tuple)):
            out.append(__recursive_serialize_list(i))
        elif isinstance(i, dict):
            out.append(__recursive_serialize_dict(i))
        else:
            out.append(i)
    return out

def __recursive_serialize_dict(x):
    out = {}
    for k in x:
        v = x[k]
        if __is_numpy_array(v):
            out[k] = __serialize_ndarray_metadata(v)
        elif isinstance(v, (list, tuple)):
            out[k] = __recursive_serialize_list(v)
        elif isinstance(v, dict):
            out[k] = __recursive_serialize_dict(v)
        else:
            out[k] = v
    return out

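The JSON that __serialize_dict produces can be consumed back in Java through the PythonUtils helpers: nested objects become Maps, arrays become Lists, and anything tagged _is_numpy_array is rebuilt as a NumpyArray. A sketch with an illustrative JSON literal:

    org.json.JSONObject obj = new org.json.JSONObject("{\"count\": 3, \"nested\": {\"flag\": true}}");
    java.util.Map<String, Object> asMap = PythonUtils.toMap(obj); // {"count"=3, "nested"={"flag"=true}}
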
@@ -0,0 +1,75 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import org.junit.Assert;
import org.junit.Test;

@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutionSandbox {

    @Test
    public void testInt(){
        PythonExecutioner.setInterpreter("interp1");
        PythonExecutioner.exec("a = 1");
        PythonExecutioner.setInterpreter("interp2");
        PythonExecutioner.exec("a = 2");
        PythonExecutioner.setInterpreter("interp3");
        PythonExecutioner.exec("a = 3");

        PythonExecutioner.setInterpreter("interp1");
        Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));

        PythonExecutioner.setInterpreter("interp2");
        Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));

        PythonExecutioner.setInterpreter("interp3");
        Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
    }

    @Test
    public void testNDArray(){
        PythonExecutioner.setInterpreter("main");
        PythonExecutioner.exec("import numpy as np");
        PythonExecutioner.exec("a = np.zeros(5)");

        PythonExecutioner.setInterpreter("main");
        //PythonExecutioner.exec("import numpy as np");
        PythonExecutioner.exec("a = np.zeros(5)");

        PythonExecutioner.setInterpreter("main");
        PythonExecutioner.exec("a += 2");

        PythonExecutioner.setInterpreter("main");
        PythonExecutioner.exec("a += 3");

        PythonExecutioner.setInterpreter("main");
        //PythonExecutioner.exec("import numpy as np");
        // PythonExecutioner.exec("a = np.zeros(5)");

        PythonExecutioner.setInterpreter("main");
        Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
    }

    @Test
    public void testNumpyRandom(){
        PythonExecutioner.setInterpreter("main");
        PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
    }
}
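
For context, a minimal sketch of the named-interpreter isolation these tests exercise; the interpreter names and variables below are illustrative, and only the PythonExecutioner calls that appear above (setInterpreter, exec, evalInteger) are assumed:

// Hypothetical usage sketch: two logically isolated Python sessions.
PythonExecutioner.setInterpreter("etl");      // switch to (or create) interpreter "etl"
PythonExecutioner.exec("rows = 10");          // this global lives only in "etl"

PythonExecutioner.setInterpreter("scoring");  // independent interpreter, fresh globals
PythonExecutioner.exec("rows = 99");

PythonExecutioner.setInterpreter("etl");      // switching back restores the earlier state
assert PythonExecutioner.evalInteger("rows") == 10;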
@@ -15,17 +15,25 @@
 ******************************************************************************/

package org.datavec.python;

-import org.junit.Ignore;
+import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;

-@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
+@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutioner {

-    @Test(timeout = 60000L)
+    @org.junit.Test
+    public void testPythonSysVersion() {
+        PythonExecutioner.exec("import sys; print(sys.version)");
+    }
+
+    @Test
    public void testStr() throws Exception{

        PythonVariables pyInputs = new PythonVariables();
@@ -47,7 +55,7 @@ public class TestPythonExecutioner {
        assertEquals("Hello World", z);
    }

-    @Test(timeout = 60000L)
+    @Test
    public void testInt()throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -55,7 +63,7 @@ public class TestPythonExecutioner {
        pyInputs.addInt("x", 10);
        pyInputs.addInt("y", 20);

        String code = "z = x + y";

        pyOutputs.addInt("z");

@@ -64,11 +72,11 @@ public class TestPythonExecutioner {

        long z = pyOutputs.getIntValue("z");

-        assertEquals(30, z);
+        Assert.assertEquals(30, z);

    }

-    @Test(timeout = 60000L)
+    @Test
    public void testList() throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -88,18 +96,35 @@ public class TestPythonExecutioner {

        Object[] z = pyOutputs.getListValue("z");

-        assertEquals(z.length, x.length + y.length);
+        Assert.assertEquals(z.length, x.length + y.length);

-        for (int i=0; i < x.length; i++){
-            assertEquals(x[i], z[i]);
+        for (int i = 0; i < x.length; i++) {
+            if(x[i] instanceof Number) {
+                Number xNum = (Number) x[i];
+                Number zNum = (Number) z[i];
+                Assert.assertEquals(xNum.intValue(), zNum.intValue());
+            }
+            else {
+                Assert.assertEquals(x[i], z[i]);
+            }
        }
-        for (int i=0; i<y.length; i++){
-            assertEquals(y[i], z[x.length + i]);
+        for (int i = 0; i < y.length; i++){
+            if(y[i] instanceof Number) {
+                Number yNum = (Number) y[i];
+                Number zNum = (Number) z[x.length + i];
+                Assert.assertEquals(yNum.intValue(), zNum.intValue());
+            }
+            else {
+                Assert.assertEquals(y[i], z[x.length + i]);
+            }
        }

    }

-    @Test(timeout = 60000L)
+    @Test
    public void testNDArrayFloat()throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -113,12 +138,17 @@ public class TestPythonExecutioner {
        PythonExecutioner.exec(code, pyInputs, pyOutputs);
        INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();

-        assertEquals(6.0, z.sum().getDouble(0), 1e-5);
+        Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);


    }

-    @Test(timeout = 60000L)
+    @Test
+    public void testTensorflowCustomAnaconda() {
+        PythonExecutioner.exec("import tensorflow as tf");
+    }
+
+    @Test
    public void testNDArrayDouble()throws Exception {
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -132,10 +162,10 @@ public class TestPythonExecutioner {
        PythonExecutioner.exec(code, pyInputs, pyOutputs);
        INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();

-        assertEquals(6.0, z.sum().getDouble(0), 1e-5);
+        Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
    }

-    @Test(timeout = 60000L)
+    @Test
    public void testNDArrayShort()throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -149,11 +179,11 @@ public class TestPythonExecutioner {
        PythonExecutioner.exec(code, pyInputs, pyOutputs);
        INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();

-        assertEquals(6.0, z.sum().getDouble(0), 1e-5);
+        Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
    }


-    @Test(timeout = 60000L)
+    @Test
    public void testNDArrayInt()throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -167,11 +197,11 @@ public class TestPythonExecutioner {
        PythonExecutioner.exec(code, pyInputs, pyOutputs);
        INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();

-        assertEquals(6.0, z.sum().getDouble(0), 1e-5);
+        Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);

    }

-    @Test(timeout = 60000L)
+    @Test
    public void testNDArrayLong()throws Exception{
        PythonVariables pyInputs = new PythonVariables();
        PythonVariables pyOutputs = new PythonVariables();
@@ -185,7 +215,7 @@ public class TestPythonExecutioner {
        PythonExecutioner.exec(code, pyInputs, pyOutputs);
        INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();

-        assertEquals(6.0, z.sum().getDouble(0), 1e-5);
+        Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);


    }
@@ -0,0 +1,27 @@
package org.datavec.python;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

@javax.annotation.concurrent.NotThreadSafe
public class TestPythonSetupAndRun {
    @Test
    public void testPythonWithSetupAndRun() throws Exception{
        String code = "def setup():" +
                "global counter;counter=0\n" +
                "def run(step):" +
                "global counter;" +
                "counter+=step;" +
                "return {\"counter\":counter}";
        PythonVariables pyInputs = new PythonVariables();
        pyInputs.addInt("step", 2);
        PythonVariables pyOutputs = new PythonVariables();
        pyOutputs.addInt("counter");
        PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
        assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
        pyInputs.addInt("step", 3);
        PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
        assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
    }
}
@@ -0,0 +1,102 @@
/*
 *
 * * ******************************************************************************
 * * * Copyright (c) 2015-2019 Skymind Inc.
 * * * Copyright (c) 2019 Konduit AI.
 * * *
 * * * This program and the accompanying materials are made available under the
 * * * terms of the Apache License, Version 2.0 which is available at
 * * * https://www.apache.org/licenses/LICENSE-2.0.
 * * *
 * * * Unless required by applicable law or agreed to in writing, software
 * * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * * License for the specific language governing permissions and limitations
 * * * under the License.
 * * *
 * * * SPDX-License-Identifier: Apache-2.0
 * * *****************************************************************************
 *
 *
 */

package org.datavec.python;

import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.Collections;

import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class TestPythonVariables {

    @Test
    public void testImportNumpy(){
        Nd4j.scalar(1.0);
        System.out.println(System.getProperty("org.bytedeco.openblas.load"));
        PythonExecutioner.exec("import numpy as np");
    }

    @Test
    public void testDataAssociations() {
        PythonVariables pythonVariables = new PythonVariables();
        PythonVariables.Type[] types = {
                PythonVariables.Type.INT,
                PythonVariables.Type.FLOAT,
                PythonVariables.Type.STR,
                PythonVariables.Type.BOOL,
                PythonVariables.Type.DICT,
                PythonVariables.Type.LIST,
                PythonVariables.Type.LIST,
                PythonVariables.Type.FILE,
                PythonVariables.Type.NDARRAY
        };

        NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
        Object[] values = {
                1L, 1.0, "1", true, Collections.singletonMap("1", 1),
                new Object[]{1}, Arrays.asList(1), "type", npArr
        };

        Object[] expectedValues = {
                1L, 1.0, "1", true, Collections.singletonMap("1", 1),
                new Object[]{1}, new Object[]{1}, "type", npArr
        };

        for(int i = 0; i < types.length; i++) {
            testInsertGet(pythonVariables, types[i].name() + i, values[i], types[i], expectedValues[i]);
        }

        assertEquals(types.length, pythonVariables.getVariables().length);
    }

    private void testInsertGet(PythonVariables pythonVariables, String key, Object value, PythonVariables.Type type, Object expectedValue) {
        pythonVariables.add(key, type);
        assertNull(pythonVariables.getValue(key));
        pythonVariables.setValue(key, value);
        assertNotNull(pythonVariables.getValue(key));
        Object actualValue = pythonVariables.getValue(key);
        if (expectedValue instanceof Object[]){
            assertTrue(actualValue instanceof Object[]);
            Object[] actualArr = (Object[]) actualValue;
            Object[] expectedArr = (Object[]) expectedValue;
            assertArrayEquals(expectedArr, actualArr);
        }
        else{
            assertEquals(expectedValue, pythonVariables.getValue(key));
        }
    }

}
@@ -29,7 +29,7 @@ public class TestSerde {
    public static JsonSerializer j = new JsonSerializer();

    @Test(timeout = 60000L)
-    public void testBasicSerde() throws Exception{
+    public void testBasicSerde(){
        Schema schema = new Schema.Builder()
                .addColumnInteger("col1")
                .addColumnFloat("col2")
@@ -37,10 +37,9 @@ public class TestSerde {
                .addColumnDouble("col4")
                .build();

-        Transform t = new PythonTransform(
-                "col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0",
-                schema
-        );
+        Transform t = PythonTransform.builder().code(
+                "col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0"
+        ).inputSchema(schema).outputSchema(schema).build();

        String yaml = y.serialize(t);
        String json = j.serialize(t);
@@ -58,7 +58,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -52,7 +52,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -171,7 +171,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -38,7 +38,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -138,7 +138,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>
@@ -247,10 +247,9 @@ public class ExecutionTest extends BaseSparkTest {
                .addColumnInteger("col1").addColumnDouble("col2").build();
        String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
        TransformProcess tp = new TransformProcess.Builder(schema).transform(
-                new PythonTransform(
-                        pythonCode,
-                        finalSchema
-                )
+                PythonTransform.builder().code(
+                        "first = np.sin(first)\nsecond = np.cos(second)")
+                        .outputSchema(finalSchema).build()
        ).build();
        List<List<Writable>> inputData = new ArrayList<>();
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));

@@ -288,10 +287,9 @@ public class ExecutionTest extends BaseSparkTest {

        String pythonCode = "col3 = col1 + col2";
        TransformProcess tp = new TransformProcess.Builder(schema).transform(
-                new PythonTransform(
-                        pythonCode,
-                        finalSchema
-                )
+                PythonTransform.builder().code(
+                        "first = np.sin(first)\nsecond = np.cos(second)")
+                        .outputSchema(schema).build()
        ).build();

        INDArray zeros = Nd4j.zeros(shape);
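
For reference, a minimal sketch of the builder-style construction that replaces the old PythonTransform constructor in this commit; the schema and Python code are illustrative, and only the builder methods visible above (code, inputSchema, outputSchema, build) plus the Schema and TransformProcess builders already used in these tests are assumed:

// Hedged sketch: building a PythonTransform via the new builder API.
Schema schema = new Schema.Builder()
        .addColumnInteger("col1")
        .addColumnDouble("col2")
        .build();

Transform t = PythonTransform.builder()
        .code("col1 += 1\ncol2 *= 2.0")   // Python executed against each record
        .inputSchema(schema)
        .outputSchema(schema)             // same columns in and out in this sketch
        .build();

TransformProcess tp = new TransformProcess.Builder(schema)
        .transform(t)
        .build();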
@@ -112,7 +112,7 @@
                        <skip>${skipTestResourceEnforcement}</skip>
                        <rules>
                            <requireActiveProfile>
-                                <profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles>
+                                <profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
                                <all>false</all>
                            </requireActiveProfile>
                        </rules>

@@ -365,11 +365,11 @@
        </profile>

        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${nd4j.version}</version>
                    <scope>test</scope>
                </dependency>

@@ -40,7 +40,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>

@@ -163,11 +163,11 @@
        </profile>

        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
@@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.Convolution1DUtils;
+import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
            }
        }
    }
+
+    @Test
+    public void testCnn1Causal() {
+        int convNIn = 2;
+        int convNOut1 = 3;
+        int convNOut2 = 4;
+        int finalNOut = 3;
+
+        int[] lengths = {11, 12, 13, 9, 10, 11};
+        int[] kernels = {2, 3, 2, 4, 2, 3};
+        int[] dilations = {1, 1, 2, 1, 2, 1};
+        int[] strides = {1, 2, 1, 2, 1, 1};
+        boolean[] masks = {false, true, false, true, false, true};
+        boolean[] hasB = {true, false, true, false, true, true};
+
+        for (int i = 0; i < lengths.length; i++) {
+            int length = lengths[i];
+            int k = kernels[i];
+            int d = dilations[i];
+            int st = strides[i];
+            boolean mask = masks[i];
+            boolean hasBias = hasB[i];
+            //TODO has bias
+            String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length;
+            log.info("Starting test: " + s);
+            Nd4j.getRandom().setSeed(12345);
+
+            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                    .dataType(DataType.DOUBLE)
+                    .updater(new NoOp())
+                    .activation(Activation.TANH)
+                    .weightInit(new NormalDistribution(0, 1))
+                    .seed(12345)
+                    .list()
+                    .layer(new Convolution1DLayer.Builder().kernelSize(k)
+                            .dilation(d)
+                            .hasBias(hasBias)
+                            .convolutionMode(ConvolutionMode.Causal)
+                            .stride(st).nIn(convNIn).nOut(convNOut1)
+                            .build())
+                    .layer(new Convolution1DLayer.Builder().kernelSize(k)
+                            .dilation(d)
+                            .convolutionMode(ConvolutionMode.Causal)
+                            .stride(st).nIn(convNOut1).nOut(convNOut2)
+                            .build())
+                    .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+                            .activation(Activation.SOFTMAX).nOut(finalNOut).build())
+                    .setInputType(InputType.recurrent(convNIn, length)).build();
+
+            MultiLayerNetwork net = new MultiLayerNetwork(conf);
+            net.init();
+
+            INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
+            INDArray fm = null;
+            if (mask) {
+                fm = Nd4j.create(2, length);
+                fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
+                fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1);
+            }
+
+            long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
+            long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
+
+            INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
+
+            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
+                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
+
+            assertTrue(s, gradOK);
+            TestUtils.testModelSerialization(net);
+        }
+    }
}
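
As a quick check of the output-size arithmetic in the test above: causal mode left-pads by (kernel - 1) * dilation, so the sequence length is unchanged at stride 1 and is divided (rounding up) by larger strides, which appears consistent with how the test derives outSize1 and outSize2. A hedged worked example, assuming Convolution1DUtils.getOutputSize behaves as the test uses it:

// length=11, k=2, st=1, d=1: causal padding of (2-1)*1 keeps the length at 11
long out1 = Convolution1DUtils.getOutputSize(11, 2, 1, 0, ConvolutionMode.Causal, 1);  // 11
// length=12, k=3, st=2, d=1: ceil(12 / 2) = 6
long out2 = Convolution1DUtils.getOutputSize(12, 3, 2, 0, ConvolutionMode.Causal, 1);  // 6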
@@ -21,12 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Ignore;
@@ -289,4 +291,66 @@ public class RnnGradientChecks extends BaseDL4JTest {
            }
        }
    }
+
+
+    @Test
+    public void testTimeDistributedDense() {
+        int nIn = 3;
+        int nOut = 5;
+        int tsLength = 4;
+        int layerSize = 8;
+
+        Random r = new Random(12345);
+        for (int mb : new int[]{1, 3}) {
+            for (boolean inputMask : new boolean[]{false, true}) {
+
+                INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
+                INDArray labels = TestUtils.randomOneHotTimeSeries(mb, nOut, tsLength);
+                String maskType = (inputMask ? "inputMask" : "none");
+
+                INDArray inMask = null;
+                if (inputMask) {
+                    inMask = Nd4j.ones(mb, tsLength);
+                    for (int i = 0; i < mb; i++) {
+                        int firstMaskedStep = tsLength - 1 - i;
+                        if (firstMaskedStep == 0) {
+                            firstMaskedStep = tsLength;
+                        }
+                        for (int j = firstMaskedStep; j < tsLength; j++) {
+                            inMask.putScalar(i, j, 0.0);
+                        }
+                    }
+                }
+
+                String name = "testLastTimeStepLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
+                if (PRINT_RESULTS) {
+                    System.out.println("Starting test: " + name);
+                }
+
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .dataType(DataType.DOUBLE)
+                        .activation(Activation.TANH)
+                        .updater(new NoOp())
+                        .weightInit(WeightInit.XAVIER)
+                        .list()
+                        .layer(new LSTM.Builder().nOut(layerSize).build())
+                        .layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
+                        .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
+                                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                        .setInputType(InputType.recurrent(nIn))
+                        .build();
+
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+
+                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
+                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16);
+                assertTrue(name, gradOK);
+                TestUtils.testModelSerialization(net);
+            }
+        }
+    }
}
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {

        assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
    }
+
+
+    @Test
+    public void testPreprocessorVertex(){
+        for(boolean withMinibatchDim : new boolean[]{true, false}){
+            long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
+            long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
+
+            for( long minibatch : new long[]{1, 3}) {
+                long[] inArrayShape = new long[]{minibatch, 32};
+                long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
+                long length = minibatch * 32;
+
+                INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
+
+                ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
+
+                for( int i=0; i<3; i++ ) {
+                    INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
+                    INDArray expOut = in.reshape(targetArrayShape);
+                    assertEquals(expOut, out);
+
+                    INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
+                    assertEquals(in, backprop);
+                }
+            }
+        }
+    }
}
@@ -16,6 +16,7 @@

package org.deeplearning4j.nn.dtypes;

+import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@@ -811,7 +812,8 @@ public class DTypeTests extends BaseDL4JTest {
                .layer(new DenseLayer.Builder().nOut(5).build())
                .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
                .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
-                .layer(new SimpleRnn.Builder().nIn(10).nOut(5).build())
+                .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
+                .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
                .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
                .layer(secondLast)
                .layer(ol)
@@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
            assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
        }
    }
+
+    @Test
+    public void testConv1dCausalAllowed(){
+        new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
+        new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
+    }
+
+    @Test
+    public void testConv2dNoCausalAllowed(){
+
+        try{
+            new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+    }
+
+    @Test
+    public void testConv3dNoCausalAllowed(){
+        try{
+            new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+
+        try{
+            new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage().toLowerCase();
+            assertTrue(m, m.contains("causal") && m.contains("1d"));
+        }
+    }
}
@@ -0,0 +1,88 @@
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import static org.junit.Assert.assertEquals;

public class TestTimeDistributed extends BaseDL4JTest {

    @Test
    public void testTimeDistributed(){
        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {

            MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                    .trainingWorkspaceMode(wsm)
                    .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new LSTM.Builder().nIn(3).nOut(3).build())
                    .layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .setInputType(InputType.recurrent(3))
                    .build();

            MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                    .trainingWorkspaceMode(wsm)
                    .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new LSTM.Builder().nIn(3).nOut(3).build())
                    .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
                    .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .setInputType(InputType.recurrent(3))
                    .build();

            MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
            MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
            net1.init();
            net2.init();

            for( int mb : new int[]{1, 5}) {
                for(char inLabelOrder : new char[]{'c', 'f'}) {
                    INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);

                    INDArray out1 = net1.output(in);
                    INDArray out2 = net2.output(in);

                    assertEquals(out1, out2);

                    INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);

                    DataSet ds = new DataSet(in, labels);
                    net1.fit(ds);
                    net2.fit(ds);

                    assertEquals(net1.params(), net2.params());

                    MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
                    out2 = net2.output(in);
                    INDArray out3 = net3.output(in);

                    assertEquals(out2, out3);
                }
            }
        }
    }
}
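
The equivalence this test asserts is worth spelling out: wrapping a DenseLayer in TimeDistributed applies the same dense transform independently at every time step of a [minibatch, size, timeSeriesLength] activation, so conf1 (plain DenseLayer between recurrent layers) and conf2 (TimeDistributed DenseLayer) produce identical outputs and identical fitted parameters. A minimal usage sketch, assuming only the constructor form used in this test (the second argument, 2, being the time dimension):

// Hedged sketch: apply a dense layer per time step inside a recurrent stack.
new NeuralNetConfiguration.Builder()
        .list()
        .layer(new LSTM.Builder().nIn(3).nOut(8).build())
        // dimension 2 is the time axis of the [mb, size, tsLength] activations
        .layer(new TimeDistributed(new DenseLayer.Builder().nIn(8).nOut(8)
                .activation(Activation.TANH).build(), 2))
        .layer(new RnnOutputLayer.Builder().nIn(8).nOut(3).activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
        .setInputType(InputType.recurrent(3))
        .build();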
@@ -16,7 +16,7 @@

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
-    <artifactId>deeplearning4j-cuda-10.1</artifactId>
+    <artifactId>deeplearning4j-cuda-10.2</artifactId>
    <name>deeplearning4j-cuda</name>
    <parent>
        <groupId>org.deeplearning4j</groupId>

@@ -26,7 +26,7 @@

    <properties>
        <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
-        <cuda.version>10.1</cuda.version>
+        <cuda.version>10.2</cuda.version>
        <cudnn.version>7.6</cudnn.version>
        <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
    </properties>

@@ -106,7 +106,7 @@
            </build>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
@@ -51,7 +51,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>

@@ -46,7 +46,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -48,7 +48,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -38,7 +38,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>

@@ -116,11 +116,11 @@
        </profile>

        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>

@@ -58,7 +58,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>

@@ -62,7 +62,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -41,7 +41,7 @@
            <id>test-nd4j-native</id>
        </profile>
        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
        </profile>
    </profiles>
</project>

@@ -302,11 +302,11 @@
        </profile>

        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>

@@ -115,11 +115,11 @@
        </profile>

        <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
@@ -356,6 +356,10 @@ public class KerasLayer {
        return this.layer;
    }

+    public void setLayer(Layer layer){
+        this.layer = layer;
+    }
+
    /**
     * Whether this Keras layer maps to a DL4J Vertex.
     *
@@ -233,6 +233,7 @@ public class KerasLayerConfiguration {
    private final String LAYER_BORDER_MODE_SAME = "same";
    private final String LAYER_BORDER_MODE_VALID = "valid";
    private final String LAYER_BORDER_MODE_FULL = "full";
+    private final String LAYER_BORDER_MODE_CAUSAL = "causal";

    /* Noise layers */
    private final String LAYER_FIELD_RATE = "rate";
@@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer {
                myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
                break;
            case 2:
-                myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                if(this.dimOrder != null) {
+                    switch (this.dimOrder) {
+                        case TENSORFLOW: //NWC == channels_last
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                            break;
+                        case THEANO: //NCW == channels_first
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            break;
+                        case NONE:
+                            //Assume RNN in [mb, seqLen, size] format
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            break;
+                        default:
+                            throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
+                    }
+                } else {
+                    //Assume RNN in [mb, seqLen, size] format
+                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                }
+
                break;
            case 3:
                switch (this.dimOrder) {
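
In plain terms, the new branch above picks the recurrent input shape from the recorded Keras dimension ordering. Assuming the InputTypeRecurrent constructor takes (size, timeSeriesLength), consistent with the InputType.recurrent(...) calls elsewhere in this commit, the mapping the switch implements is:

// TENSORFLOW (channels_last):        size = inputShape[1], timeSeriesLength = inputShape[0]
// THEANO (channels_first) and NONE:  size = inputShape[0], timeSeriesLength = inputShape[1]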
@@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
+import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.ArrayList;
@@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer {
     */
    public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
        if (type instanceof InputType.InputTypeFeedForward) {
-            this.layer = new LossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
        }
        else if (type instanceof InputType.InputTypeRecurrent) {
-            this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
        }
        else if (type instanceof InputType.InputTypeConvolutional) {
-            this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build();
+            this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported output layer type"
                    + "got : " + type.toString());
@@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer {
    public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
-
    }

    /**
@@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution {
                break;

            case THEANO:
-                paramValue = kerasParamValue.permute(2, 1, 0);
-                paramValue = paramValue.reshape(
-                        paramValue.size(0), paramValue.size(1),
-                        paramValue.size(2), 1).dup();
-                for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) {
-                    INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
-                    double[] flattenedFilter = copyFilter.ravel().data().asDouble();
-                    ArrayUtils.reverse(flattenedFilter);
-                    INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
-                    INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
-                    inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType()));
-                }
+                //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1]
+                long k = kerasParamValue.size(0);
+                long nIn = kerasParamValue.size(1);
+                long nOut = kerasParamValue.size(2);
+                paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
                break;
            default:
                throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());
@@ -264,7 +264,8 @@ public class KerasConvolutionUtils {
        } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
                borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
            convolutionMode = ConvolutionMode.Truncate;
-
+        } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
+            convolutionMode = ConvolutionMode.Causal;
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
        }
@@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
            // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
            InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
            val inputShape = new long[]{it.getSize()};
-            preprocessor = new ReshapePreprocessor(inputShape, inputShape);
+            preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
        }
        return preprocessor;
    }
@@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
                 } else {
                     targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
                 }
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
                 if (inputShape[0] != targetShape[0])
                     targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             }

         } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {

@@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer {
                 } else {
                     targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
                 }
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             } else {
                 if (inputShape[0] != targetShape[0])
                     targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             }
         } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
             InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
             val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
-            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
         } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
             InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
             val inputShape = new long[]{it.getSize()};
             if (targetShape.length == 3) {
                 targetShape = targetShapeForDimOrder(inputShape, targetShape);
             }
-            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
         }
         return preprocessor;
     }
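Note (illustration, not part of this commit): every call site above now passes hasMiniBatchDimension = false, because shapes read from a Keras config never include the minibatch dimension. Illustrative use of the new three-argument constructor, with shapes taken from the javadoc example added later in this diff (the import path is an assumption):

import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

// Sketch: reshape [minibatch, 32] -> [minibatch, 2, 4, 4]; shapes are given
// without the minibatch dimension, exactly as the import code above does.
ReshapePreprocessor proc = new ReshapePreprocessor(new long[]{32}, new long[]{2, 4, 4}, false);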
@@ -23,11 +23,13 @@ import lombok.val;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
 import org.deeplearning4j.nn.conf.layers.LSTM;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

@@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer {
                         .biasInit(0.0) // TODO: this is incorrect
                         .l1(this.weightL1Regularization)
                         .l2(this.weightL2Regularization);
+        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
+        if(nIn != null)
+            builder.setNIn(nIn);
         if (biasConstraint != null)
             builder.constrainBias(biasConstraint);
         if (weightConstraint != null)
@@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer {
             log.warn("Attemping to set weights for unknown parameters: "
                             + unknownParamNames.substring(1, unknownParamNames.length() - 1));
         }
+
+        FeedForwardLayer ffl;
+        if(this.layer instanceof BaseWrapperLayer){
+            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+            ffl = (FeedForwardLayer)bwl.getUnderlying();
+        } else {
+            ffl = (FeedForwardLayer) this.layer;
+        }
+        if(ffl.getNIn() != wRows){
+            //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
+            //We can reliably infer nIn from the shape of the weights array however
+            ffl.setNIn(wRows);
+        }
     }

     /**
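Note (illustration, not part of this commit): the setWeights workaround relies on the Keras convention that an LSTM's input-to-hidden kernel has shape [nIn, 4 * units], so its row count recovers nIn when the layer config is ambiguous. A standalone sketch of that inference (names are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Sketch only: recover nIn from a Keras-style LSTM input weight matrix of shape [nIn, 4 * units]
public class NInInferenceSketch {
    public static void main(String[] args) {
        int units = 64;
        INDArray W = Nd4j.rand(128, 4 * units); // stand-in for the matrix loaded from HDF5 weights
        System.out.println("Inferred nIn = " + W.rows()); // 128, regardless of what the config says
    }
}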
@@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

@@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer {
                         .biasInit(0.0)
                         .l1(this.weightL1Regularization)
                         .l2(this.weightL2Regularization);
+        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
+        if(nIn != null)
+            builder.setNIn(nIn);
         if (biasConstraint != null)
             builder.constrainBias(biasConstraint);
         if (weightConstraint != null)
@@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer {
             log.warn("Attemping to set weights for unknown parameters: "
                             + unknownParamNames.substring(1, unknownParamNames.length() - 1));
         }
+
+        FeedForwardLayer ffl;
+        if(this.layer instanceof BaseWrapperLayer){
+            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+            ffl = (FeedForwardLayer)bwl.getUnderlying();
+        } else {
+            ffl = (FeedForwardLayer) this.layer;
+        }
+        if(ffl.getNIn() != W.rows()){
+            //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
+            //We can reliably infer nIn from the shape of the weights array however
+            ffl.setNIn(W.rows());
+        }
     }

 }
@@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer {
     @Override
     public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {

-        Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward");
-        Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward");
+        Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
+        Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");

         this.weights = new HashMap<>();

@@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer {
     }


-    private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction)
+    private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction)
                     throws InvalidKerasConfigurationException {
         int keras1SubstringLength;
         if (kerasRnnlayer instanceof KerasLSTM)

@@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer {
             weights = newWeights;
         }

+        Layer layerBefore = kerasRnnlayer.getLayer();
+        kerasRnnlayer.setLayer(l);
         kerasRnnlayer.setWeights(weights);
-        return kerasRnnlayer.getWeights();
+        Map<String,INDArray> ret = kerasRnnlayer.getWeights();
+        kerasRnnlayer.setLayer(layerBefore);
+        return ret;
     }

 }
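Note (illustration, not part of this commit): the getUnderlyingWeights change is a swap-and-restore pattern: the shared kerasRnnlayer is temporarily pointed at the forward or backward half of the Bidirectional wrapper so its weight-mapping logic can be reused for each direction, and the original layer is put back before returning. The same pattern in isolation, with generic names:

import java.util.function.Consumer;
import java.util.function.Supplier;

// Sketch: run an action while a shared reference temporarily points elsewhere, then restore it
public class SwapRestoreSketch {
    static <T, R> R withTemporary(T temp, Supplier<T> getter, Consumer<T> setter, Supplier<R> action) {
        T before = getter.get();      // save the current target (cf. layerBefore in the diff)
        setter.accept(temp);          // point at the fwd/bwd half
        try {
            return action.get();      // setWeights + getWeights in the diff
        } finally {
            setter.accept(before);    // restore even if the action throws
        }
    }
}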
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -20,7 +21,6 @@ import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;

@@ -36,73 +36,72 @@ import java.util.Arrays;
 import static org.nd4j.linalg.util.ArrayUtil.prodLong;

 /**
- * Generic reshape preprocessor
+ * Generic reshape preprocessor.
+ * Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
+ * is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}<br>
+ * For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:<br>
+ * hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR<br>
+ * hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
  *
  * @author Max Pumperla
  */
 @Data
 @Slf4j
 @EqualsAndHashCode(callSuper = false)
-@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
+@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
 public class ReshapePreprocessor extends BaseInputPreProcessor {

-    private long[] inputShape;
-    private long[] targetShape;
-    private boolean hasMiniBatchDimension = false;
-    private int miniBatchSize;
-    private long[] staticTargetShape;
+    private final long[] inputShape;
+    private final long[] targetShape;
+    private boolean hasMiniBatchDimension;

-    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
-        this.inputShape = inputShape;
-        this.targetShape = targetShape;
+    /**
+     * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
+     */
+    @Deprecated
+    public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
+        this(inputShape, targetShape, false);
     }

-    private static int prod(int[] array) {
-        int prod = 1;
-        for (int i : array) {
-            prod *= i;
+    /**
+     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
+     */
+    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
+                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
+        this.inputShape = inputShape;
+        this.targetShape = targetShape;
+        this.hasMiniBatchDimension = hasMiniBatchDimension;
+    }
+
+    private long[] getShape(long[] originalShape, long minibatch) {
+        long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
+        if (newShape[0] != minibatch) {
+            newShape = newShape.clone();
+            newShape[0] = minibatch;
         }
-        return prod;
+        return newShape;
     }

     private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
         int shapeLength = shape.length;
         val miniBatchShape = new long[shapeLength + 1];
-        for (int i = 0; i < miniBatchShape.length; i++) {
-            if (i == 0)
-                miniBatchShape[i] = miniBatchSize;
-            else
-                miniBatchShape[i] = shape[i - 1];
+        miniBatchShape[0] = miniBatchSize;
+        for (int i = 1; i < miniBatchShape.length; i++) {
+            miniBatchShape[i] = shape[i - 1];
         }
         return miniBatchShape;
     }

     @Override
     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
-        // the target shape read from a keras config does not have mini-batch size
-        // included. We prepend it here dynamically.
-
-        long[] targetShape;
-        if (staticTargetShape != null){
-            targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
-            hasMiniBatchDimension = true;
-            this.miniBatchSize = miniBatchSize;
-        }
-        else{
-            targetShape = this.targetShape;
-        }
-        if (!this.hasMiniBatchDimension) {
-            targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
-            inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
-            this.miniBatchSize = miniBatchSize;
-        }
-        if (this.miniBatchSize != miniBatchSize) {
-            targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
-            inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
-            this.miniBatchSize = miniBatchSize;
-        }
+        // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
+        long[] targetShape = getShape(this.targetShape, miniBatchSize);
+        long[] inputShape = getShape(this.inputShape, miniBatchSize);
+
         if (prodLong(input.shape()) == prodLong((targetShape))) {
-            if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
+            if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
                 input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
             }
             return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));

@@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {

     @Override
     public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+        long[] targetShape = getShape(this.targetShape, miniBatchSize);
+        long[] inputShape = getShape(this.inputShape, miniBatchSize);
+
         if (!Arrays.equals(targetShape, output.shape())) {
             throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
                             + " (expected to be " + Arrays.toString(targetShape) + ")");
         }
         if (prodLong(output.shape()) == prodLong((targetShape))) {
-            if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){
+            if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
                 output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
             }
-            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
         } else {
             throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
                             + " and input shape" + Arrays.toString(targetShape) + " do not match");

@@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {

     @Override
     public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
-        val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
+        long[] shape = getShape(this.targetShape, 0);
         InputType ret;
         switch (shape.length) {
             case 2:

@@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
                 ret = InputType.recurrent(shape[2], shape[1]);
                 break;
             case 4:
-                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
+                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                     ret = InputType.convolutional(shape[1], shape[2], shape[3]);
-                }else {
+                } else {
                     ret = InputType.convolutional(shape[2], shape[3], shape[1]);
                 }
                 break;
             default:
                 throw new UnsupportedOperationException(
                                 "Cannot infer input type for reshape array " + Arrays.toString(shape));
         }
-        this.staticTargetShape = ret.getShape();
        return ret;
     }
 }
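Note (illustration, not part of this commit): the net effect of the rewrite is that preProcess and backprop now derive both shapes per call from the stored shapes plus the current minibatch size, instead of mutating fields. The shape arithmetic of the new getShape helper, restated as a standalone sketch:

// Sketch of the new getShape logic: prepend the minibatch dim if absent, then force index 0 to match it
static long[] withMinibatch(long[] shape, long minibatch, boolean hasMiniBatchDimension) {
    long[] out;
    if (hasMiniBatchDimension) {
        out = shape.clone();
    } else {
        out = new long[shape.length + 1];
        System.arraycopy(shape, 0, out, 1, shape.length);
    }
    out[0] = minibatch; // e.g. [32] -> [batch, 32], or [-1, 32] -> [batch, 32]
    return out;
}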
@@ -505,6 +505,17 @@ public class KerasLayerUtils {
         return nOut;
     }

+    public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
+        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
+        if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){
+            Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
+            if(id instanceof Number){
+                return ((Number)id).intValue();
+            }
+        }
+        return null;
+    }
+
     /**
      * Get dropout from Keras layer configuration.
      *
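Note (illustration, not part of this commit): getNInFromInputDim reads the optional input dimension that Keras serializes when the user sets one explicitly; when absent it returns null and the builder falls back to shape inference. A hand-built example of the lookup, assuming "input_dim" is the Keras field name behind conf.getLAYER_FIELD_INPUT_DIM():

import java.util.HashMap;
import java.util.Map;

// Hand-built inner config illustrating the lookup the new method performs
Map<String, Object> innerConfig = new HashMap<>();
innerConfig.put("input_dim", 128); // as Keras would serialize it
Object id = innerConfig.get("input_dim");
Integer nIn = (id instanceof Number) ? ((Number) id).intValue() : null; // -> 128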
@@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {

     @Test
     public void ReshapeEmbeddingConcatTest() throws Exception{
+        //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441

         try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
             ComputationGraphConfiguration config =
                     new KerasModel().modelBuilder().modelJsonInputStream(is)
                             .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
             ComputationGraph model = new ComputationGraph(config);
             model.init();
+//            System.out.println(model.summary());
             model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
         }
     }
@@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass;
 import org.deeplearning4j.gradientcheck.GradientCheckUtil;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.layers.IOutputLayer;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
 import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;

@@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.function.BiFunction;
+import org.nd4j.linalg.function.Function;
 import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

@@ -58,10 +62,7 @@ import java.io.InputStream;
 import java.net.URL;
 import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
+import java.util.*;

 import static org.junit.Assert.*;
@@ -86,7 +87,16 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     @Rule
     public final TemporaryFolder testDir = new TemporaryFolder();

+    public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
+        @Override
+        public INDArray apply(String s, INDArray array) {
+            if(array.rank() == 3)
+                return array.permute(0, 2, 1); //NWC to NCW
+            return array;
+        }
+    };
+
     @Test(expected = IllegalStateException.class)
     public void fileNotFoundEndToEnd() throws Exception {
         String modelPath = "modelimport/keras/examples/foo/bar.h5";
         importEndModelTest(modelPath, null, true, true, false, false);
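Note (illustration, not part of this commit): nwc2ncwExpected replaces the old blanket permute of DL4J activations further down in this diff; it is now the Keras-side expected arrays that are permuted from NWC (batch, width/time, channels) to DL4J's NCW layout. A quick demonstration of what the permute does to a shape:

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Demonstration: NWC (Keras) -> NCW (DL4J) swaps dims 1 and 2
public class PermuteDemo {
    public static void main(String[] args) {
        INDArray nwc = Nd4j.zeros(8, 100, 32);            // [batch, timeSteps, channels]
        INDArray ncw = nwc.permute(0, 2, 1);              // view with shape [8, 32, 100]
        System.out.println(Arrays.toString(ncw.shape())); // [8, 32, 100]
    }
}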
@@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     public void importImdbLstmTfKeras1() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmThKeras1() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmTfKeras2() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     @Test
     public void importImdbLstmThKeras2() throws Exception {
         String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
     }

     /**

@@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
         String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
                 "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
     }

     /**
@@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         model.summary();
     }

+    @Test
+    public void testCausalCon1D() throws Exception {
+        String[] names = new String[]{
+                "causal_conv1d_k2_s1_d1_cl_model.h5",
+                "causal_conv1d_k2_s1_d2_cl_model.h5",
+                "causal_conv1d_k2_s2_d1_cl_model.h5",
+                "causal_conv1d_k2_s3_d1_cl_model.h5",
+                "causal_conv1d_k3_s1_d1_cl_model.h5",
+                "causal_conv1d_k3_s1_d2_cl_model.h5",
+                "causal_conv1d_k3_s2_d1_cl_model.h5",
+                "causal_conv1d_k3_s3_d1_cl_model.h5",
+                "causal_conv1d_k4_s1_d1_cl_model.h5",
+                "causal_conv1d_k4_s1_d2_cl_model.h5",
+                "causal_conv1d_k4_s2_d1_cl_model.h5",
+                "causal_conv1d_k4_s3_d1_cl_model.h5"
+        };
+
+        for(String name : names ){
+            System.out.println("Starting test: " + name);
+            String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
+            String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
+            Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
+                @Override
+                public INDArray apply(INDArray i) {
+                    //NWC to NCW
+                    return i.permute(0, 2, 1);
+                }
+            };
+
+            MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
+                    true, true, false, f, nwc2ncwExpected);
+            Layer l = net.getLayer(0);
+            Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
+            assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
+        }
+    }
+
+    @Test
+    public void testCon1D() throws Exception {
+        String[] names = new String[]{
+                "conv1d_k2_s1_d1_cf_same_model.h5",
+                "conv1d_k2_s1_d1_cf_valid_model.h5",
+                "conv1d_k2_s1_d1_cl_same_model.h5",
+                "conv1d_k2_s1_d1_cl_valid_model.h5",
+                "conv1d_k2_s1_d2_cf_same_model.h5",
+                "conv1d_k2_s1_d2_cf_valid_model.h5",
+                "conv1d_k2_s1_d2_cl_same_model.h5",
+                "conv1d_k2_s1_d2_cl_valid_model.h5",
+                "conv1d_k2_s2_d1_cf_same_model.h5",
+                "conv1d_k2_s2_d1_cf_valid_model.h5",
+                "conv1d_k2_s2_d1_cl_same_model.h5",
+                "conv1d_k2_s2_d1_cl_valid_model.h5",
+                "conv1d_k2_s3_d1_cf_same_model.h5",
+                "conv1d_k2_s3_d1_cf_valid_model.h5",
+                "conv1d_k2_s3_d1_cl_same_model.h5",
+                "conv1d_k2_s3_d1_cl_valid_model.h5",
+                "conv1d_k3_s1_d1_cf_same_model.h5",
+                "conv1d_k3_s1_d1_cf_valid_model.h5",
+                "conv1d_k3_s1_d1_cl_same_model.h5",
+                "conv1d_k3_s1_d1_cl_valid_model.h5",
+                "conv1d_k3_s1_d2_cf_same_model.h5",
+                "conv1d_k3_s1_d2_cf_valid_model.h5",
+                "conv1d_k3_s1_d2_cl_same_model.h5",
+                "conv1d_k3_s1_d2_cl_valid_model.h5",
+                "conv1d_k3_s2_d1_cf_same_model.h5",
+                "conv1d_k3_s2_d1_cf_valid_model.h5",
+                "conv1d_k3_s2_d1_cl_same_model.h5",
+                "conv1d_k3_s2_d1_cl_valid_model.h5",
+                "conv1d_k3_s3_d1_cf_same_model.h5",
+                "conv1d_k3_s3_d1_cf_valid_model.h5",
+                "conv1d_k3_s3_d1_cl_same_model.h5",
+                "conv1d_k3_s3_d1_cl_valid_model.h5",
+                "conv1d_k4_s1_d1_cf_same_model.h5",
+                "conv1d_k4_s1_d1_cf_valid_model.h5",
+                "conv1d_k4_s1_d1_cl_same_model.h5",
+                "conv1d_k4_s1_d1_cl_valid_model.h5",
+                "conv1d_k4_s1_d2_cf_same_model.h5",
+                "conv1d_k4_s1_d2_cf_valid_model.h5",
+                "conv1d_k4_s1_d2_cl_same_model.h5",
+                "conv1d_k4_s1_d2_cl_valid_model.h5",
+                "conv1d_k4_s2_d1_cf_same_model.h5",
+                "conv1d_k4_s2_d1_cf_valid_model.h5",
+                "conv1d_k4_s2_d1_cl_same_model.h5",
+                "conv1d_k4_s2_d1_cl_valid_model.h5",
+                "conv1d_k4_s3_d1_cf_same_model.h5",
+                "conv1d_k4_s3_d1_cf_valid_model.h5",
+                "conv1d_k4_s3_d1_cl_same_model.h5",
+                "conv1d_k4_s3_d1_cl_valid_model.h5",
+        };
+
+        for(String name : names ){
+            System.out.println("Starting test: " + name);
+            String modelPath = "modelimport/keras/examples/conv1d/" + name;
+            String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
+            Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
+                @Override
+                public INDArray apply(INDArray i) {
+                    //NWC to NCW
+                    return i.permute(0, 2, 1);
+                }
+            };
+
+            BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
+                @Override
+                public INDArray apply(String s, INDArray array) {
+//                    if("conv".equals(s)){
+                    return array.permute(0, 2, 1);
+//                    }
+                }
+            };
+
+            importEndModelTest(modelPath, inputsOutputPath, true, true,
+                    true, true, false, f, f2);
+        }
+    }
+
     private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
         return importFunctionalModelH5Test(modelPath, null, false);
     }
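Note (illustration, not part of this commit): the test matrix above sweeps kernel size (k), stride (s), dilation (d), data format (cf/cl = channels_first/last) and padding mode. For reference, the output lengths these cases should produce, assuming the standard Keras conventions for 1D convolution output size:

// Output length of a 1D convolution over n steps (standard Keras formulas, stated as an assumption)
static long outLen(long n, long k, long s, long d, String padding) {
    long kEff = (k - 1) * d + 1;               // effective kernel size under dilation
    switch (padding) {
        case "same":
        case "causal": return (n + s - 1) / s; // ceil(n / s); length preserved when s == 1
        case "valid":  return (n - kEff) / s + 1;
        default: throw new IllegalArgumentException(padding);
    }
}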
@@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {

     public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
                                                 boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
+        return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null);
+    }
+
+    public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
+                                                boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc,
+                                                BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception {
         MultiLayerNetwork model;
         try(InputStream is = Resources.asStream(modelPath)) {
             File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);

@@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {

         if (checkPredictions) {
             INDArray input = getInputs(outputsArchive, tfOrdering)[0];
+            if(inputPreProc != null)
+                input = inputPreProc.apply(input);
+
             Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
             for (int i = 0; i < model.getLayers().length; i++) {
                 String layerName = model.getLayerNames().get(i);
                 if (activationsKeras.containsKey(layerName)) {
                     INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
-                    if (activationsDl4j.shape().length == 3)
-                        activationsDl4j = activationsDl4j.permute(0, 2, 1);
-                    compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS);
+                    INDArray exp = activationsKeras.get(layerName);
+                    if(expectedPreProc != null)
+                        exp = expectedPreProc.apply(layerName, exp);
+                    compareINDArrays(layerName, exp, activationsDl4j, EPS);
                 }
             }

             INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
             INDArray predictionsDl4j = model.output(input, false);
+            if(expectedPreProc != null)
+                predictionsKeras = expectedPreProc.apply("output", predictionsKeras);
             compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
             INDArray outputs = getOutputs(outputsArchive, true)[0];

@@ -680,7 +817,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
             }
             val nOut = (int) outputs.size(-1);

-            compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
+            if(checkAuc)
+                compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
         }

         if (checkGradients && ! SKIP_GRAD_CHECKS) {

@@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         return predictions;
     }

-    private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) {
-        INDArray diff = a.sub(b.castTo(a.dataType()));
+    private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) {
+        if(!expected.equalShapes(actual)){
+            throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
+        }
+        INDArray diff = expected.sub(actual.castTo(expected.dataType()));
         double min = diff.minNumber().doubleValue();
         double max = diff.maxNumber().doubleValue();
-        log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max);
+        log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max);
         double threshold = 1e-7;
-        double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue()));
-        double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue()));
+        double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue()));
+        double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue()));

         // skip too small absolute inputs
         if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
-            assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps));
+            boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
+            assertTrue("Output differs: " + label, eq);
         }

     }

     private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,
@@ -1,5 +1,6 @@
 <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   ~ Copyright (c) 2015-2018 Skymind, Inc.
+  ~ Copyright (c) 2019 Konduit K.K.
   ~
   ~ This program and the accompanying materials are made available under the
   ~ terms of the Apache License, Version 2.0 which is available at

@@ -23,16 +24,11 @@
     </parent>
     <modelVersion>4.0.0</modelVersion>

-    <artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId>
+    <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
     <packaging>jar</packaging>

     <name>deeplearning4j-nearestneighbor-server</name>

-    <properties>
-        <!-- Default scala versions, may be overwritten by build profiles -->
-        <scala.version>2.11.12</scala.version>
-        <scala.binary.version>2.11</scala.binary.version>
-    </properties>
     <build>
         <pluginManagement>
             <plugins>

@@ -73,29 +69,17 @@
         </dependency>

         <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-java_2.11</artifactId>
-            <version>${playframework.version}</version>
-            <exclusions>
-                <exclusion>
-                    <groupId>com.google.code.findbugs</groupId>
-                    <artifactId>jsr305</artifactId>
-                </exclusion>
-                <exclusion>
-                    <groupId>org.apache.tomcat</groupId>
-                    <artifactId>tomcat-servlet-api</artifactId>
-                </exclusion>
-                <exclusion>
-                    <groupId>net.jodah</groupId>
-                    <artifactId>typetools</artifactId>
-                </exclusion>
-            </exclusions>
+            <groupId>io.vertx</groupId>
+            <artifactId>vertx-core</artifactId>
+            <version>${vertx.version}</version>
         </dependency>

         <dependency>
-            <groupId>net.jodah</groupId>
-            <artifactId>typetools</artifactId>
-            <version>${jodah.typetools.version}</version>
+            <groupId>io.vertx</groupId>
+            <artifactId>vertx-web</artifactId>
+            <version>${vertx.version}</version>
         </dependency>

         <dependency>
             <groupId>com.mashape.unirest</groupId>
             <artifactId>unirest-java</artifactId>

@@ -108,25 +92,16 @@
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-json_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-server_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
         <dependency>
             <groupId>com.beust</groupId>
             <artifactId>jcommander</artifactId>
             <version>${jcommander.version}</version>
         </dependency>

         <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-netty-server_2.11</artifactId>
-            <version>${playframework.version}</version>
+            <groupId>ch.qos.logback</groupId>
+            <artifactId>logback-classic</artifactId>
+            <scope>test</scope>
         </dependency>
     </dependencies>

@@ -144,11 +119,11 @@
         </profile>

         <profile>
-            <id>test-nd4j-cuda-10.1</id>
+            <id>test-nd4j-cuda-10.2</id>
             <dependencies>
                 <dependency>
                     <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <artifactId>nd4j-cuda-10.2</artifactId>
                     <version>${project.version}</version>
                     <scope>test</scope>
                 </dependency>
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
 import com.beust.jcommander.JCommander;
 import com.beust.jcommander.Parameter;
 import com.beust.jcommander.ParameterException;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.vertx.core.AbstractVerticle;
+import io.vertx.core.Vertx;
+import io.vertx.ext.web.Router;
+import io.vertx.ext.web.handler.BodyHandler;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.clustering.sptree.DataPoint;

@@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
 import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nearestneighbor.model.*;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;

@@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.serde.base64.Nd4jBase64;
 import org.nd4j.serde.binary.BinarySerde;
-import play.BuiltInComponents;
-import play.Mode;
-import play.libs.Json;
-import play.routing.Router;
-import play.routing.RoutingDsl;
-import play.server.Server;

 import java.io.File;
 import java.util.*;

-import static play.mvc.Controller.request;
-import static play.mvc.Results.*;

 /**
  * A rest server for using an
  * {@link VPTree} based on loading an ndarray containing

@@ -57,22 +55,33 @@ import static play.mvc.Results.*;
  * @author Adam Gibson
  */
 @Slf4j
-public class NearestNeighborsServer {
-    @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
-    private String ndarrayPath = null;
-    @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
-    private String labelsPath = null;
-    @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
-    private int port = 9000;
-    @Parameter(names = {"--similarityFunction"}, arity = 1)
-    private String similarityFunction = "euclidean";
-    @Parameter(names = {"--invert"}, arity = 1)
-    private boolean invert = false;
-
-    private Server server;
+public class NearestNeighborsServer extends AbstractVerticle {
+
+    private static class RunArgs {
+        @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
+        private String ndarrayPath = null;
+        @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
+        private String labelsPath = null;
+        @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
+        private int port = 9000;
+        @Parameter(names = {"--similarityFunction"}, arity = 1)
+        private String similarityFunction = "euclidean";
+        @Parameter(names = {"--invert"}, arity = 1)
+        private boolean invert = false;
+    }

-    public void runMain(String... args) throws Exception {
-        JCommander jcmdr = new JCommander(this);
+    private static RunArgs instanceArgs;
+    private static NearestNeighborsServer instance;
+
+    public NearestNeighborsServer(){ }
+
+    public static NearestNeighborsServer getInstance(){
+        return instance;
+    }
+
+    public static void runMain(String... args) {
+        RunArgs r = new RunArgs();
+        JCommander jcmdr = new JCommander(r);

         try {
             jcmdr.parse(args);

@@ -84,7 +93,7 @@ public class NearestNeighborsServer {

             //User provides invalid input -> print the usage info
             jcmdr.usage();
-            if (ndarrayPath == null)
+            if (r.ndarrayPath == null)
                 log.error("Json path parameter is missing (null)");
             try {
                 Thread.sleep(500);

@@ -93,16 +102,20 @@ public class NearestNeighborsServer {
             System.exit(1);
         }

+        instanceArgs = r;
         try {
-            runHelper();
+            Vertx vertx = Vertx.vertx();
+            vertx.deployVerticle(NearestNeighborsServer.class.getName());
         } catch (Throwable t){
             log.error("Error in NearestNeighboursServer run method",t);
         }
     }

-    protected void runHelper() throws Exception {
+    @Override
+    public void start() throws Exception {
+        instance = this;

-        String[] pathArr = ndarrayPath.split(",");
+        String[] pathArr = instanceArgs.ndarrayPath.split(",");
         //INDArray[] pointsArr = new INDArray[pathArr.length];
         // first of all we reading shapes of saved eariler files
         int rows = 0;

@@ -111,7 +124,7 @@ public class NearestNeighborsServer {
             DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));

             log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
                             Shape.size(shape, 1));

             if (Shape.rank(shape) != 2)
                 throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");

@@ -122,12 +135,12 @@ public class NearestNeighborsServer {
                 cols = Shape.size(shape, 1);
             else if (cols != Shape.size(shape, 1))
                 throw new DL4JInvalidInputException(
                                 "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
         }

         final List<String> labels = new ArrayList<>();
-        if (labelsPath != null) {
-            String[] labelsPathArr = labelsPath.split(",");
+        if (instanceArgs.labelsPath != null) {
+            String[] labelsPathArr = instanceArgs.labelsPath.split(",");
             for (int i = 0; i < labelsPathArr.length; i++) {
                 labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
             }

@@ -149,7 +162,7 @@ public class NearestNeighborsServer {
             System.gc();
         }

-        VPTree tree = new VPTree(points, similarityFunction, invert);
+        VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);

         //Set play secret key, if required
         //http://www.playframework.com/documentation/latest/ApplicationSecret

@@ -163,40 +176,57 @@ public class NearestNeighborsServer {
             System.setProperty("play.crypto.secret", base64);
         }

+        Router r = Router.router(vertx);
+        r.route().handler(BodyHandler.create());  //NOTE: Setting this is required to receive request body content at all
+        createRoutes(r, labels, tree, points);

-        server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
+        vertx.createHttpServer()
+                .requestHandler(r)
+                .listen(instanceArgs.port);
     }

-    protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){
-        RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
+    private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){

-        //return the host information for a given id
-        routingDsl.POST("/knn").routingTo(request -> {
+        r.post("/knn").handler(rc -> {
             try {
-                NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
+                String json = rc.getBodyAsJson().encode();
+                NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
+
                 NearestNeighbor nearestNeighbor =
                                 NearestNeighbor.builder().points(points).record(record).tree(tree).build();

-                if (record == null)
-                    return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
+                if (record == null) {
+                    rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+                            .putHeader("content-type", "application/json")
+                            .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
+                    return;
+                }

-                NearestNeighborsResults results =
-                                NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
+                NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();

-                return ok(Json.toJson(results));
||||||
|
|
||||||
|
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(JsonMappers.getMapper().writeValueAsString(results));
|
||||||
|
return;
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
log.error("Error in POST /knn",e);
|
log.error("Error in POST /knn",e);
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
return internalServerError(e.getMessage());
|
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
|
.end("Error parsing request - " + e.getMessage());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
routingDsl.POST("/knnnew").routingTo(request -> {
|
r.post("/knnnew").handler(rc -> {
|
||||||
try {
|
try {
|
||||||
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
|
String json = rc.getBodyAsJson().encode();
|
||||||
if (record == null)
|
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
|
||||||
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
|
if (record == null) {
|
||||||
|
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
||||||
List<DataPoint> results;
|
List<DataPoint> results;
|
||||||
|
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (results.size() != distances.size()) {
|
if (results.size() != distances.size()) {
|
||||||
return internalServerError(
|
rc.response()
|
||||||
String.format("results.size == %d != %d == distances.size",
|
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
results.size(), distances.size()));
|
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
||||||
|
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
||||||
return ok(Json.toJson(results2));
|
String j = JsonMappers.getMapper().writeValueAsString(results2);
|
||||||
|
rc.response()
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(j);
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
log.error("Error in POST /knnnew",e);
|
log.error("Error in POST /knnnew",e);
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
return internalServerError(e.getMessage());
|
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
|
.end("Error parsing request - " + e.getMessage());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return routingDsl.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop the server
|
* Stop the server
|
||||||
*/
|
*/
|
||||||
public void stop() {
|
public void stop() throws Exception {
|
||||||
if (server != null) {
|
super.stop();
|
||||||
log.info("Attempting to stop server");
|
|
||||||
server.stop();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
new NearestNeighborsServer().runMain(args);
|
runMain(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
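The change above replaces the Play-framework RoutingDsl server with a Vert.x verticle: the static runMain now just deploys the verticle, start() builds a Router, and each route writes to the RoutingContext response instead of returning a Play Result. A stripped-down sketch of that pattern, assuming Vert.x 3.x on the classpath (the /echo route and port 8080 are illustrative, not part of this commit):

import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;

public class EchoVerticle extends AbstractVerticle {
    @Override
    public void start() {
        Router router = Router.router(vertx);
        //BodyHandler must be installed before any route that reads the request
        //body, mirroring the r.route().handler(BodyHandler.create()) call above
        router.route().handler(BodyHandler.create());
        router.post("/echo").handler(rc -> rc.response()
                .putHeader("content-type", "application/json")
                .end(rc.getBodyAsJson().encode()));
        vertx.createHttpServer().requestHandler(router).listen(8080);
    }

    public static void main(String[] args) {
        //As in the new runMain: deploy the verticle rather than starting Play
        Vertx.vertx().deployVerticle(new EchoVerticle());
    }
}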
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
|
||||||
public void testNearestNeighbor() {
|
public void testNearestNeighbor() {
|
||||||
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
|
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
|
||||||
INDArray arr = Nd4j.create(data);
|
INDArray arr = Nd4j.create(data);
|
||||||
|
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
||||||
File writeToTmp = testDir.newFile();
|
File writeToTmp = testDir.newFile();
|
||||||
writeToTmp.deleteOnExit();
|
writeToTmp.deleteOnExit();
|
||||||
BinarySerde.writeArrayToDisk(rand, writeToTmp);
|
BinarySerde.writeArrayToDisk(rand, writeToTmp);
|
||||||
NearestNeighborsServer server = new NearestNeighborsServer();
|
NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
||||||
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
|
||||||
String.valueOf(localPort));
|
String.valueOf(localPort));
|
||||||
|
|
||||||
|
Thread.sleep(3000);
|
||||||
|
|
||||||
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
|
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
|
||||||
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
|
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
|
||||||
assertEquals(5, result.getResults().size());
|
assertEquals(5, result.getResults().size());
|
||||||
server.stop();
|
NearestNeighborsServer.getInstance().stop();
|
||||||
}
|
}
|
||||||
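The Thread.sleep(3000) added above gives the now-asynchronous Vert.x server time to bind its port before the client connects. If the fixed sleep ever proves flaky, a hedged alternative is to poll the port until it accepts connections (helper name and timeout values are illustrative, not part of this commit):

static void awaitPort(String host, int port, long timeoutMs) throws Exception {
    long deadline = System.currentTimeMillis() + timeoutMs;
    while (true) {
        try (java.net.Socket s = new java.net.Socket(host, port)) {
            return; //server is accepting connections
        } catch (java.io.IOException e) {
            if (System.currentTimeMillis() > deadline)
                throw new IllegalStateException("Server did not start within " + timeoutMs + " ms", e);
            Thread.sleep(100); //retry until the deadline
        }
    }
}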
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2019 Konduit K.K.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<configuration>
|
||||||
|
|
||||||
|
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||||
|
<file>logs/application.log</file>
|
||||||
|
<encoder>
|
||||||
|
<pattern>%date - [%level] - from %logger in %thread
|
||||||
|
%n%message%n%xException%n</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||||
|
<encoder>
|
||||||
|
<pattern> %logger{15} - %message%n%xException{5}
|
||||||
|
</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<logger name="org.deeplearning4j" level="INFO" />
|
||||||
|
<logger name="org.datavec" level="INFO" />
|
||||||
|
<logger name="org.nd4j" level="INFO" />
|
||||||
|
|
||||||
|
<root level="ERROR">
|
||||||
|
<appender-ref ref="STDOUT" />
|
||||||
|
<appender-ref ref="FILE" />
|
||||||
|
</root>
|
||||||
|
</configuration>
|
|
@ -54,7 +54,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -53,7 +53,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -83,11 +83,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.1</artifactId>
|
<artifactId>nd4j-cuda-10.2</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~ Copyright (c) 2019 Konduit K.K.
|
||||||
~
|
~
|
||||||
~ This program and the accompanying materials are made available under the
|
~ This program and the accompanying materials are made available under the
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -43,7 +44,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -66,7 +66,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -68,7 +68,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.1</id>
|
<id>test-nd4j-cuda-10.2</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.primitives.Triple;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -85,10 +86,20 @@ import java.util.Map;
|
||||||
* <pre>
|
* <pre>
|
||||||
* {@code
|
* {@code
|
||||||
* BertIterator b;
|
* BertIterator b;
|
||||||
|
* Pair<INDArray[],INDArray[]> featuresAndMask;
|
||||||
|
* INDArray[] features;
|
||||||
|
* INDArray[] featureMasks;
|
||||||
|
*
|
||||||
|
* //With sentences
|
||||||
* List<String> forInference;
|
* List<String> forInference;
|
||||||
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference);
|
* featuresAndMask = b.featurizeSentences(forInference);
|
||||||
* INDArray[] features = featuresAndMask.getFirst();
|
*
|
||||||
* INDArray[] featureMasks = featuresAndMask.getSecond();
|
* //OR with sentence pairs
|
||||||
|
* List<Pair<String, String>> forInferencePair;
|
||||||
|
* featuresAndMask = b.featurizeSentencePairs(forInferencePair);
|
||||||
|
*
|
||||||
|
* features = featuresAndMask.getFirst();
|
||||||
|
* featureMasks = featuresAndMask.getSecond();
|
||||||
* }
|
* }
|
||||||
* </pre>
|
* </pre>
|
||||||
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
|
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
|
||||||
|
@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
@Setter
|
@Setter
|
||||||
protected MultiDataSetPreProcessor preProcessor;
|
protected MultiDataSetPreProcessor preProcessor;
|
||||||
protected LabeledSentenceProvider sentenceProvider = null;
|
protected LabeledSentenceProvider sentenceProvider = null;
|
||||||
|
protected LabeledPairSentenceProvider sentencePairProvider = null;
|
||||||
protected LengthHandling lengthHandling;
|
protected LengthHandling lengthHandling;
|
||||||
protected FeatureArrays featureArrays;
|
protected FeatureArrays featureArrays;
|
||||||
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
|
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
|
||||||
|
@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
protected String prependToken;
|
protected String prependToken;
|
||||||
|
protected String appendToken;
|
||||||
|
|
||||||
|
|
||||||
protected List<String> vocabKeysAsList;
|
protected List<String> vocabKeysAsList;
|
||||||
|
@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
this.padMinibatches = b.padMinibatches;
|
this.padMinibatches = b.padMinibatches;
|
||||||
this.preProcessor = b.preProcessor;
|
this.preProcessor = b.preProcessor;
|
||||||
this.sentenceProvider = b.sentenceProvider;
|
this.sentenceProvider = b.sentenceProvider;
|
||||||
|
this.sentencePairProvider = b.sentencePairProvider;
|
||||||
this.lengthHandling = b.lengthHandling;
|
this.lengthHandling = b.lengthHandling;
|
||||||
this.featureArrays = b.featureArrays;
|
this.featureArrays = b.featureArrays;
|
||||||
this.vocabMap = b.vocabMap;
|
this.vocabMap = b.vocabMap;
|
||||||
|
@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
|
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
|
||||||
this.maskToken = b.maskToken;
|
this.maskToken = b.maskToken;
|
||||||
this.prependToken = b.prependToken;
|
this.prependToken = b.prependToken;
|
||||||
|
this.appendToken = b.appendToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasNext() {
|
public boolean hasNext() {
|
||||||
return sentenceProvider.hasNext();
|
if (sentenceProvider != null)
|
||||||
|
return sentenceProvider.hasNext();
|
||||||
|
return sentencePairProvider.hasNext();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
@Override
|
@Override
|
||||||
public MultiDataSet next(int num) {
|
public MultiDataSet next(int num) {
|
||||||
Preconditions.checkState(hasNext(), "No next element available");
|
Preconditions.checkState(hasNext(), "No next element available");
|
||||||
|
List<Pair<List<String>, String>> tokensAndLabelList;
|
||||||
List<Pair<String, String>> list = new ArrayList<>(num);
|
|
||||||
int mbSize = 0;
|
int mbSize = 0;
|
||||||
|
int outLength;
|
||||||
|
long[] segIdOnesFrom = null;
|
||||||
if (sentenceProvider != null) {
|
if (sentenceProvider != null) {
|
||||||
|
List<Pair<String, String>> list = new ArrayList<>(num);
|
||||||
while (sentenceProvider.hasNext() && mbSize++ < num) {
|
while (sentenceProvider.hasNext() && mbSize++ < num) {
|
||||||
list.add(sentenceProvider.nextSentence());
|
list.add(sentenceProvider.nextSentence());
|
||||||
}
|
}
|
||||||
|
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
|
||||||
|
tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
|
||||||
|
outLength = sentenceListProcessed.getMaxL();
|
||||||
|
} else if (sentencePairProvider != null) {
|
||||||
|
List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
|
||||||
|
while (sentencePairProvider.hasNext() && mbSize++ < num) {
|
||||||
|
listPairs.add(sentencePairProvider.nextSentencePair());
|
||||||
|
}
|
||||||
|
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
|
||||||
|
tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
|
||||||
|
outLength = sentencePairListProcessed.getMaxL();
|
||||||
|
segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
|
||||||
} else {
|
} else {
|
||||||
//TODO - other types of iterators...
|
//TODO - other types of iterators...
|
||||||
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
|
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
|
||||||
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
|
|
||||||
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
|
|
||||||
int outLength = outLTokenizedSentencesPair.getLeft();
|
|
||||||
|
|
||||||
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
|
|
||||||
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
|
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
|
||||||
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
|
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
|
||||||
|
|
||||||
|
|
||||||
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
|
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
|
||||||
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
|
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
|
||||||
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
|
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
|
||||||
|
|
||||||
|
@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
|
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
|
||||||
|
|
||||||
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
|
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
|
||||||
|
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
|
||||||
|
List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
|
||||||
|
int outLength = sentenceListProcessed.getMaxL();
|
||||||
|
|
||||||
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
|
|
||||||
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
|
|
||||||
int outLength = outLTokenizedSentencesPair.getLeft();
|
|
||||||
|
|
||||||
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
|
|
||||||
if (preProcessor != null) {
|
if (preProcessor != null) {
|
||||||
|
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
|
||||||
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
|
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
|
||||||
preProcessor.preProcess(dummyMDS);
|
preProcessor.preProcess(dummyMDS);
|
||||||
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
|
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
|
||||||
}
|
}
|
||||||
return convertMiniBatchFeatures(tokenizedSentences, outLength);
|
return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
|
/**
|
||||||
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
|
* For use during inference. Converts a given list of sentence pairs to features and feature masks as appropriate.
|
||||||
|
*
|
||||||
|
* @param listOnlySentencePairs Sentence pairs to featurize (no labels required)
|
||||||
|
* @return Pair of INDArray[]; the first element holds the feature arrays and the second the feature mask arrays
|
||||||
|
*/
|
||||||
|
public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
|
||||||
|
Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
|
||||||
|
|
||||||
|
List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
|
||||||
|
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
|
||||||
|
List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
|
||||||
|
int outLength = sentencePairListProcessed.getMaxL();
|
||||||
|
long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
|
||||||
|
if (preProcessor != null) {
|
||||||
|
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
|
||||||
|
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
|
||||||
|
preProcessor.preProcess(dummyMDS);
|
||||||
|
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
|
||||||
|
}
|
||||||
|
return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
|
||||||
|
int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
|
||||||
int[][] outIdxs = new int[mbPadded][outLength];
|
int[][] outIdxs = new int[mbPadded][outLength];
|
||||||
int[][] outMask = new int[mbPadded][outLength];
|
int[][] outMask = new int[mbPadded][outLength];
|
||||||
for (int i = 0; i < tokenizedSentences.size(); i++) {
|
int[][] outSegmentId = null;
|
||||||
Pair<List<String>, String> p = tokenizedSentences.get(i);
|
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
|
outSegmentId = new int[mbPadded][outLength];
|
||||||
|
for (int i = 0; i < tokensAndLabelList.size(); i++) {
|
||||||
|
Pair<List<String>, String> p = tokensAndLabelList.get(i);
|
||||||
List<String> t = p.getFirst();
|
List<String> t = p.getFirst();
|
||||||
for (int j = 0; j < outLength && j < t.size(); j++) {
|
for (int j = 0; j < outLength && j < t.size(); j++) {
|
||||||
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
|
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
|
||||||
int idx = vocabMap.get(t.get(j));
|
int idx = vocabMap.get(t.get(j));
|
||||||
outIdxs[i][j] = idx;
|
outIdxs[i][j] = idx;
|
||||||
outMask[i][j] = 1;
|
outMask[i][j] = 1;
|
||||||
|
if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
|
||||||
|
outSegmentId[i][j] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
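The loop above is where sentence pairs pick up BERT-style segment IDs: every token position from segIdOnesFrom[i] onward belongs to the second sentence and gets segment ID 1, while the feature mask separately marks which positions hold real tokens. A worked sketch of the indexing, with an assumed 8-slot packed sequence:

static void segmentIdsForPair() {
    //Assumed layout: [CLS] a b [SEP] c d [SEP] <pad> -> 7 real tokens, 1 pad slot
    int outLength = 8;
    int numTokens = 7;
    long segIdOnesFrom = 4;              //index of the first second-sentence token
    int[] mask = new int[outLength];
    int[] segmentId = new int[outLength];
    for (int j = 0; j < outLength && j < numTokens; j++) {
        mask[j] = 1;                     //a real token occupies this position
        if (j >= segIdOnesFrom)
            segmentId[j] = 1;            //token belongs to the second sentence
    }
    //mask      -> [1, 1, 1, 1, 1, 1, 1, 0]
    //segmentId -> [0, 0, 0, 0, 1, 1, 1, 0]   (padding keeps segment ID 0)
}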
||||||
|
@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
INDArray[] f;
|
INDArray[] f;
|
||||||
INDArray[] fm;
|
INDArray[] fm;
|
||||||
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
|
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
|
||||||
//For now: always segment index 0 (only single s sequence input supported)
|
outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
|
||||||
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
|
|
||||||
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
|
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
|
||||||
fm = new INDArray[]{outMaskArr, null};
|
fm = new INDArray[]{outMaskArr, null};
|
||||||
} else {
|
} else {
|
||||||
|
@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
return new Pair<>(f, fm);
|
return new Pair<>(f, fm);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) {
|
private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
|
||||||
//Get and tokenize the sentences for this minibatch
|
//Get and tokenize the sentences for this minibatch
|
||||||
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size());
|
SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
|
||||||
int longestSeq = -1;
|
int longestSeq = -1;
|
||||||
for (Pair<String, String> p : list) {
|
for (Pair<String, String> p : list) {
|
||||||
List<String> tokens = tokenizeSentence(p.getFirst());
|
List<String> tokens = tokenizeSentence(p.getFirst());
|
||||||
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
|
sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
|
||||||
longestSeq = Math.max(longestSeq, tokens.size());
|
longestSeq = Math.max(longestSeq, tokens.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Determine output array length...
|
//Determine output array length...
|
||||||
int outLength;
|
int outLength;
|
||||||
switch (lengthHandling) {
|
switch (lengthHandling) {
|
||||||
|
@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
|
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
|
||||||
}
|
}
|
||||||
return new Pair<>(outLength, tokenizedSentences);
|
sentenceListProcessed.setMaxL(outLength);
|
||||||
|
return sentenceListProcessed;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
|
||||||
|
SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
|
||||||
|
for (Triple<String, String, String> t : listPairs) {
|
||||||
|
List<String> tokensL = tokenizeSentence(t.getFirst(), true);
|
||||||
|
List<String> tokensR = tokenizeSentence(t.getSecond(), true);
|
||||||
|
List<String> tokens = new ArrayList<>(maxTokens);
|
||||||
|
int maxLength = maxTokens;
|
||||||
|
if (prependToken != null)
|
||||||
|
maxLength--;
|
||||||
|
if (appendToken != null)
|
||||||
|
maxLength -= 2;
|
||||||
|
if (tokensL.size() + tokensR.size() > maxLength) {
|
||||||
|
boolean shortOnL = tokensL.size() < tokensR.size();
|
||||||
|
int shortSize = Math.min(tokensL.size(), tokensR.size());
|
||||||
|
if (shortSize > maxLength / 2) {
|
||||||
|
//both lists need to be sliced
|
||||||
|
tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxLength is odd, the L side gives up the extra token, matching the TF implementation
|
||||||
|
tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
|
||||||
|
} else {
|
||||||
|
//slice longer list
|
||||||
|
if (shortOnL) {
|
||||||
|
//longer on R - slice R
|
||||||
|
tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
|
||||||
|
} else {
|
||||||
|
//longer on L - slice L
|
||||||
|
tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (prependToken != null)
|
||||||
|
tokens.add(prependToken);
|
||||||
|
tokens.addAll(tokensL);
|
||||||
|
if (appendToken != null)
|
||||||
|
tokens.add(appendToken);
|
||||||
|
int segIdOnesFrom = tokens.size();
|
||||||
|
tokens.addAll(tokensR);
|
||||||
|
if (appendToken != null)
|
||||||
|
tokens.add(appendToken);
|
||||||
|
sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
|
||||||
|
}
|
||||||
|
sentencePairListProcessed.setMaxL(maxTokens);
|
||||||
|
return sentencePairListProcessed;
|
||||||
}
|
}
|
||||||
|
|
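The slicing logic in tokenizePairsMiniBatch above keeps a packed pair within maxTokens: one slot is reserved for the prepend token and two for the append tokens (one [SEP] after each sentence), and whichever side exceeds its share of the remaining budget gets trimmed. A worked sketch of that arithmetic with assumed sizes:

static void pairTruncationBudget() {
    int maxTokens = 16;                  //assumed fixed length
    int maxLength = maxTokens;
    maxLength--;                         //15: reserve one slot for "[CLS]"
    maxLength -= 2;                      //13: reserve one "[SEP]" after each sentence
    //Say tokensL has 10 tokens and tokensR has 8; 18 > 13, so truncate.
    //The shorter side (8) exceeds maxLength / 2 = 6, so both sides are sliced:
    int keepL = maxLength / 2;           //6 - the L side gives up the odd extra token, as in TF
    int keepR = maxLength - keepL;       //7
    //Packed: 1 ([CLS]) + 6 + 1 ([SEP]) + 7 + 1 ([SEP]) = 16 = maxTokens
    System.out.println("keepL=" + keepL + ", keepR=" + keepR);
}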
||||||
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
|
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
|
||||||
|
@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
classLabels[i] = labels.indexOf(lbl);
|
classLabels[i] = labels.indexOf(lbl);
|
||||||
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
|
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
|
||||||
}
|
}
|
||||||
|
} else if (sentencePairProvider != null) {
|
||||||
|
numClasses = sentencePairProvider.numLabelClasses();
|
||||||
|
List<String> labels = sentencePairProvider.allLabels();
|
||||||
|
for (int i = 0; i < mbSize; i++) {
|
||||||
|
String lbl = tokenizedSentences.get(i).getRight();
|
||||||
|
classLabels[i] = labels.indexOf(lbl);
|
||||||
|
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> tokenizeSentence(String sentence) {
|
private List<String> tokenizeSentence(String sentence) {
|
||||||
|
return tokenizeSentence(sentence, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
|
||||||
Tokenizer t = tokenizerFactory.create(sentence);
|
Tokenizer t = tokenizerFactory.create(sentence);
|
||||||
|
|
||||||
List<String> tokens = new ArrayList<>();
|
List<String> tokens = new ArrayList<>();
|
||||||
if (prependToken != null)
|
if (prependToken != null && !ignorePrependAppend)
|
||||||
tokens.add(prependToken);
|
tokens.add(prependToken);
|
||||||
|
|
||||||
while (t.hasMoreTokens()) {
|
while (t.hasMoreTokens()) {
|
||||||
String token = t.nextToken();
|
String token = t.nextToken();
|
||||||
tokens.add(token);
|
tokens.add(token);
|
||||||
}
|
}
|
||||||
|
if (appendToken != null && !ignorePrependAppend)
|
||||||
|
tokens.add(appendToken);
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
|
||||||
|
List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
|
||||||
|
for (Pair<String, String> p : listOnlySentencePairs) {
|
||||||
|
list.add(new Triple<String, String, String>(p.getFirst(), p.getSecond(), null));
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean resetSupported() {
|
public boolean resetSupported() {
|
||||||
|
@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
protected boolean padMinibatches = false;
|
protected boolean padMinibatches = false;
|
||||||
protected MultiDataSetPreProcessor preProcessor;
|
protected MultiDataSetPreProcessor preProcessor;
|
||||||
protected LabeledSentenceProvider sentenceProvider = null;
|
protected LabeledSentenceProvider sentenceProvider = null;
|
||||||
|
protected LabeledPairSentenceProvider sentencePairProvider = null;
|
||||||
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
|
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
|
||||||
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
|
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
|
||||||
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
protected BertSequenceMasker masker = new BertMaskedLMMasker();
|
||||||
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
|
||||||
protected String maskToken;
|
protected String maskToken;
|
||||||
protected String prependToken;
|
protected String prependToken;
|
||||||
|
protected String appendToken;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
|
||||||
|
@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
|
* Specify the source of the data for classification.
|
||||||
* use case, the labels will be ignored.
|
|
||||||
*/
|
*/
|
||||||
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
|
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
|
||||||
this.sentenceProvider = sentenceProvider;
|
this.sentenceProvider = sentenceProvider;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specify the source of the data for classification on sentence pairs.
|
||||||
|
*/
|
||||||
|
public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
|
||||||
|
this.sentencePairProvider = sentencePairProvider;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify what arrays should be returned. See {@link BertIterator} for more details.
|
* Specify what arrays should be returned. See {@link BertIterator} for more details.
|
||||||
*/
|
*/
|
||||||
|
@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Append the specified token to the sequences when training on sentence pairs.<br>
|
||||||
|
* Generally "[SEP]" is used.<br>
|
||||||
|
* No token is appended by default.
|
||||||
|
*
|
||||||
|
* @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
|
||||||
|
* @return This builder, for method chaining
|
||||||
|
*/
|
||||||
|
public Builder appendToken(String appendToken) {
|
||||||
|
this.appendToken = appendToken;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
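Putting the new pieces together, a hedged sketch of configuring the iterator for sentence-pair classification follows. tokenizerFactory (e.g. a BertWordPieceTokenizerFactory), vocab, and pairProvider are assumed to exist, and the builder methods not visible in this diff (tokenizer, vocabMap, task, lengthHandling, featureArrays) are assumptions inferred from the fields and error messages above:

BertIterator iter = BertIterator.builder()
        .tokenizer(tokenizerFactory)
        .vocabMap(vocab)
        .task(BertIterator.Task.SEQ_CLASSIFICATION)                        //required for pairs
        .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)    //required for pairs
        .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) //required for pairs
        .sentencePairProvider(pairProvider)                               //new in this change
        .prependToken("[CLS]")
        .appendToken("[SEP]")                                             //new in this change
        .build();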
||||||
public BertIterator build() {
|
public BertIterator build() {
|
||||||
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
|
||||||
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
|
||||||
|
@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
|
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
|
||||||
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via the unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
|
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via the unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
|
||||||
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\") must be specified");
|
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\") must be specified");
|
||||||
|
if (sentencePairProvider != null) {
|
||||||
|
Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
|
||||||
|
Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
|
||||||
|
Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
|
||||||
|
Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
|
||||||
|
}
|
||||||
|
if (appendToken != null) {
|
||||||
|
Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
|
||||||
|
}
|
||||||
return new BertIterator(this);
|
return new BertIterator(this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class SentencePairListProcessed {
|
||||||
|
private int listLength = 0;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private long[] segIdOnesFrom;
|
||||||
|
private int cursor = 0;
|
||||||
|
private SentenceListProcessed sentenceListProcessed;
|
||||||
|
|
||||||
|
private SentencePairListProcessed(int listLength) {
|
||||||
|
this.listLength = listLength;
|
||||||
|
segIdOnesFrom = new long[listLength];
|
||||||
|
sentenceListProcessed = new SentenceListProcessed(listLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
|
||||||
|
segIdOnesFrom[cursor] = segIdIdx;
|
||||||
|
sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
|
||||||
|
cursor++;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setMaxL(int maxL) {
|
||||||
|
sentenceListProcessed.setMaxL(maxL);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getMaxL() {
|
||||||
|
return sentenceListProcessed.getMaxL();
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Pair<List<String>, String>> getTokensAndLabelList() {
|
||||||
|
return sentenceListProcessed.getTokensAndLabelList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class SentenceListProcessed {
|
||||||
|
private int listLength;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
private int maxL;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private List<Pair<List<String>, String>> tokensAndLabelList;
|
||||||
|
|
||||||
|
private SentenceListProcessed(int listLength) {
|
||||||
|
this.listLength = listLength;
|
||||||
|
tokensAndLabelList = new ArrayList<>(listLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
|
||||||
|
tokensAndLabelList.add(tokenizedSentenceAndLabel);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|