Merge pull request #8495 from KonduitAI/master

Update master
master
Alex Black 2019-12-05 11:05:44 +11:00 committed by GitHub
commit 3275fe35a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
565 changed files with 17579 additions and 5224 deletions

View File

@ -91,7 +91,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -70,7 +70,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -305,7 +305,7 @@ public class TestGraphLocalExecution {
@Test @Test
public void testLocalExecutionEarlyStopping() throws Exception { public void testLocalExecutionEarlyStopping() throws Exception {
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>() EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(6)) .epochTerminationConditions(new MaxEpochsTerminationCondition(4))
.scoreCalculator(new ScoreProvider()) .scoreCalculator(new ScoreProvider())
.modelSaver(new InMemoryModelSaver()).build(); .modelSaver(new InMemoryModelSaver()).build();
Map<String, Object> commands = new HashMap<>(); Map<String, Object> commands = new HashMap<>();
@ -348,7 +348,7 @@ public class TestGraphLocalExecution {
.dataProvider(dataProvider) .dataProvider(dataProvider)
.scoreFunction(ScoreFunctions.testSetF1()) .scoreFunction(ScoreFunctions.testSetF1())
.modelSaver(new FileModelSaver(modelSavePath)) .modelSaver(new FileModelSaver(modelSavePath))
.terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
new MaxCandidatesCondition(10)) new MaxCandidatesCondition(10))
.build(); .build();

View File

@ -32,7 +32,7 @@ public class TestDataFactoryProviderMnist implements DataSetIteratorFactory {
private int terminationIter; private int terminationIter;
public TestDataFactoryProviderMnist(){ public TestDataFactoryProviderMnist(){
this(16, 10); this(16, 4);
} }
@Override @Override

View File

@ -56,7 +56,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -37,7 +37,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -151,7 +151,7 @@
<skip>${skipTestResourceEnforcement}</skip> <skip>${skipTestResourceEnforcement}</skip>
<rules> <rules>
<requireActiveProfile> <requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles> <profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
<all>false</all> <all>false</all>
</requireActiveProfile> </requireActiveProfile>
</rules> </rules>
@ -333,11 +333,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -20,7 +20,7 @@
set -e set -e
VALID_VERSIONS=( 9.2 10.0 10.1 ) VALID_VERSIONS=( 9.2 10.0 10.1 10.2 )
usage() { usage() {
echo "Usage: $(basename $0) [-h|--help] <cuda version to be used> echo "Usage: $(basename $0) [-h|--help] <cuda version to be used>
@ -47,6 +47,10 @@ check_cuda_version() {
check_cuda_version "$VERSION" check_cuda_version "$VERSION"
case $VERSION in case $VERSION in
10.2)
VERSION2="7.6"
VERSION3="1.5.2"
;;
10.1) 10.1)
VERSION2="7.6" VERSION2="7.6"
VERSION3="1.5.2" VERSION3="1.5.2"

View File

@ -117,7 +117,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -56,7 +56,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -110,7 +110,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -72,7 +72,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -59,7 +59,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -126,7 +126,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -67,7 +67,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -58,7 +58,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -58,7 +58,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -49,7 +49,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -67,7 +67,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -65,7 +65,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -88,7 +88,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -256,11 +256,9 @@ public class ExecutionTest {
TransformProcess transformProcess = new TransformProcess.Builder(schema) TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform( .transform(
new PythonTransform( PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)", "first = np.sin(first)\nsecond = np.cos(second)")
schema .outputSchema(schema).build())
)
)
.build(); .build();
List<List<Writable>> functions = new ArrayList<>(); List<List<Writable>> functions = new ArrayList<>();

View File

@ -14,35 +14,40 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.datavec.python; package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter; import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter; import org.datavec.api.transform.filter.Filter;
import org.datavec.api.writable.*;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.junit.Ignore; import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*;
@NotThreadSafe
public class TestPythonTransformProcess { public class TestPythonTransformProcess {
@Test(timeout = 60000L)
@Test()
public void testStringConcat() throws Exception{ public void testStringConcat() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnString("col1") .addColumnString("col1")
.addColumnString("col2"); .addColumnString("col2");
@ -54,10 +59,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2"; String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema) PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build(); ).build();
List<Writable> inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!")); List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs); List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello "); assertEquals((outputs.get(0)).toString(), "Hello ");
@ -68,7 +75,7 @@ public class TestPythonTransformProcess {
@Test(timeout = 60000L) @Test(timeout = 60000L)
public void testMixedTypes() throws Exception{ public void testMixedTypes() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnInteger("col1") .addColumnInteger("col1")
.addColumnFloat("col2") .addColumnFloat("col2")
@ -83,11 +90,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)"; String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema) PythonTransform.builder().code(pythonCode)
).build(); .outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList((Writable) List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
new IntWritable(10),
new FloatWritable(3.5f), new FloatWritable(3.5f),
new Text("5"), new Text("5"),
new DoubleWritable(2.0) new DoubleWritable(2.0)
@ -105,7 +113,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2); INDArray expectedOutput = arr1.add(arr2);
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnNDArray("col1", shape) .addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape); .addColumnNDArray("col2", shape);
@ -116,11 +124,13 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2"; String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema) PythonTransform.builder().code(pythonCode)
).build(); .outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList( List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1), (Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2) new NDArrayWritable(arr2)
); );
@ -139,7 +149,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2); INDArray expectedOutput = arr1.add(arr2);
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnNDArray("col1", shape) .addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape); .addColumnNDArray("col2", shape);
@ -150,11 +160,13 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2"; String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema) PythonTransform.builder().code(pythonCode)
).build(); .outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList( List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1), (Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2) new NDArrayWritable(arr2)
); );
@ -172,7 +184,7 @@ public class TestPythonTransformProcess {
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape); INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE)); INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnNDArray("col1", shape) .addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape); .addColumnNDArray("col2", shape);
@ -183,11 +195,14 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2"; String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema) PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build(); ).build();
List<Writable> inputs = Arrays.asList( List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1), (Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2) new NDArrayWritable(arr2)
); );
@ -199,8 +214,8 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test(timeout = 60000L)
public void testPythonFilter(){ public void testPythonFilter() {
Schema schema = new Schema.Builder().addColumnInteger("column").build(); Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition( Condition condition = new PythonCondition(
"f = lambda: column < 0" "f = lambda: column < 0"
@ -210,17 +225,17 @@ public class TestPythonTransformProcess {
Filter filter = new ConditionFilter(condition); Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10)))); assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1)))); assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0)))); assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1)))); assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10)))); assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
} }
@Test(timeout = 60000L) @Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{ public void testPythonFilterAndTransform() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnInteger("col1") .addColumnInteger("col1")
.addColumnFloat("col2") .addColumnFloat("col2")
@ -241,33 +256,85 @@ public class TestPythonTransformProcess {
String pythonCode = "col6 = str(col1 + col2)"; String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform( PythonTransform.builder().code(pythonCode)
pythonCode, .outputSchema(finalSchema)
finalSchema .build()
)
).filter( ).filter(
filter filter
).build(); ).build();
List<List<Writable>> inputs = new ArrayList<>(); List<List<Writable>> inputs = new ArrayList<>();
inputs.add( inputs.add(
Arrays.asList((Writable) new IntWritable(5), Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f), new FloatWritable(3.0f),
new Text("abcd"), new Text("abcd"),
new DoubleWritable(2.1)) new DoubleWritable(2.1))
); );
inputs.add( inputs.add(
Arrays.asList((Writable) new IntWritable(-3), Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f), new FloatWritable(3.0f),
new Text("abcd"), new Text("abcd"),
new DoubleWritable(2.1)) new DoubleWritable(2.1))
); );
inputs.add( inputs.add(
Arrays.asList((Writable) new IntWritable(5), Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f), new FloatWritable(11.2f),
new Text("abcd"), new Text("abcd"),
new DoubleWritable(2.1)) new DoubleWritable(2.1))
); );
LocalTransformExecutor.execute(inputs,tp);
} }
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals("hello world",execute.get(0).get(0).toString());
}
} }

View File

@ -59,7 +59,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -28,15 +28,21 @@
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>com.googlecode.json-simple</groupId> <groupId>org.json</groupId>
<artifactId>json-simple</artifactId> <artifactId>json</artifactId>
<version>1.1</version> <version>20190722</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>cpython-platform</artifactId> <artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version> <version>${cpython-platform.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>numpy-platform</artifactId>
<version>${numpy.javacpp.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId> <artifactId>jsr305</artifactId>
@ -65,7 +71,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -16,10 +16,13 @@
package org.datavec.python; package org.datavec.python;
import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType;
* @author Fariz Rahman * @author Fariz Rahman
*/ */
@Getter @Getter
@NoArgsConstructor
public class NumpyArray { public class NumpyArray {
private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); private static NativeOps nativeOps;
private long address; private long address;
private long[] shape; private long[] shape;
private long[] strides; private long[] strides;
private DataType dtype = DataType.FLOAT; private DataType dtype;
private INDArray nd4jArray; private INDArray nd4jArray;
static {
//initialize
Nd4j.scalar(1.0);
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
}
public NumpyArray(long address, long[] shape, long strides[], boolean copy){ @Builder
public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
this.address = address; this.address = address;
this.shape = shape; this.shape = shape;
this.strides = strides; this.strides = strides;
this.dtype = dtype;
setND4JArray(); setND4JArray();
if (copy){ if (copy){
nd4jArray = nd4jArray.dup(); nd4jArray = nd4jArray.dup();
@ -57,8 +68,9 @@ public class NumpyArray {
public NumpyArray copy(){ public NumpyArray copy(){
return new NumpyArray(nd4jArray.dup()); return new NumpyArray(nd4jArray.dup());
} }
public NumpyArray(long address, long[] shape, long strides[]){ public NumpyArray(long address, long[] shape, long strides[]){
this(address, shape, strides, false); this(address, shape, strides, false,DataType.FLOAT);
} }
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
@ -77,9 +89,9 @@ public class NumpyArray {
} }
} }
private void setND4JArray(){ private void setND4JArray() {
long size = 1; long size = 1;
for(long d: shape){ for(long d: shape) {
size *= d; size *= d;
} }
Pointer ptr = nativeOps.pointerForAddress(address); Pointer ptr = nativeOps.pointerForAddress(address);
@ -88,10 +100,11 @@ public class NumpyArray {
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype); DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
int elemSize = buff.getElementSize(); int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length]; long[] nd4jStrides = new long[strides.length];
for (int i=0; i<strides.length; i++){ for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize; nd4jStrides[i] = strides[i] / elemSize;
} }
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, 'c', dtype);
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
} }

View File

@ -23,6 +23,8 @@ import org.datavec.api.writable.*;
import java.util.List; import java.util.List;
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
/** /**
* Lets a condition be defined as a python method f that takes no arguments * Lets a condition be defined as a python method f that takes no arguments
* and returns a boolean indicating whether or not to filter a row. * and returns a boolean indicating whether or not to filter a row.
@ -38,81 +40,28 @@ public class PythonCondition implements Condition {
private String code; private String code;
public PythonCondition(String pythonCode){ public PythonCondition(String pythonCode) {
org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
code = pythonCode; code = pythonCode;
} }
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i=0; i<numCols; i++){
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
private PythonVariables getPyInputsFromWritables(List<Writable> writables){
PythonVariables ret = new PythonVariables();
for (String name: pyInputs.getVariables()){
int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx);
PythonVariables.Type pyType = pyInputs.getType(name);
switch (pyType){
case INT:
if (w instanceof LongWritable){
ret.addInt(name, ((LongWritable)w).get());
}
else{
ret.addInt(name, ((IntWritable)w).get());
}
break;
case FLOAT:
ret.addFloat(name, ((DoubleWritable)w).get());
break;
case STR:
ret.addStr(name, ((Text)w).toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
}
}
return ret;
}
@Override @Override
public void setInputSchema(Schema inputSchema){ public void setInputSchema(Schema inputSchema) {
this.inputSchema = inputSchema; this.inputSchema = inputSchema;
try{ try{
pyInputs = schemaToPythonVariables(inputSchema); pyInputs = schemaToPythonVariables(inputSchema);
PythonVariables pyOuts = new PythonVariables(); PythonVariables pyOuts = new PythonVariables();
pyOuts.addInt("out"); pyOuts.addInt("out");
pythonTransform = new PythonTransform( pythonTransform = PythonTransform.builder()
code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered .code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
pyInputs, .inputs(pyInputs)
pyOuts .outputs(pyOuts)
); .build();
} }
catch (Exception e){ catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException(e);
@ -127,41 +76,47 @@ public class PythonCondition implements Condition {
return inputSchema; return inputSchema;
} }
public String[] outputColumnNames(){ @Override
public String[] outputColumnNames() {
String[] columnNames = new String[inputSchema.numColumns()]; String[] columnNames = new String[inputSchema.numColumns()];
inputSchema.getColumnNames().toArray(columnNames); inputSchema.getColumnNames().toArray(columnNames);
return columnNames; return columnNames;
} }
@Override
public String outputColumnName(){ public String outputColumnName(){
return outputColumnNames()[0]; return outputColumnNames()[0];
} }
@Override
public String[] columnNames(){ public String[] columnNames(){
return outputColumnNames(); return outputColumnNames();
} }
@Override
public String columnName(){ public String columnName(){
return outputColumnName(); return outputColumnName();
} }
@Override
public Schema transform(Schema inputSchema){ public Schema transform(Schema inputSchema){
return inputSchema; return inputSchema;
} }
public boolean condition(List<Writable> list){ @Override
public boolean condition(List<Writable> list) {
PythonVariables inputs = getPyInputsFromWritables(list); PythonVariables inputs = getPyInputsFromWritables(list);
try{ try{
PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs()); PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0; boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
return ret; return ret;
} }
catch (Exception e){ catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Override
public boolean condition(Object input){ public boolean condition(Object input){
return condition(input); return condition(input);
} }
@ -177,5 +132,37 @@ public class PythonCondition implements Condition {
throw new UnsupportedOperationException("not supported"); throw new UnsupportedOperationException("not supported");
} }
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables();
for (int i = 0; i < inputSchema.numColumns(); i++){
String name = inputSchema.getName(i);
Writable w = writables.get(i);
PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
switch (pyType){
case INT:
if (w instanceof LongWritable) {
ret.addInt(name, ((LongWritable)w).get());
}
else {
ret.addInt(name, ((IntWritable)w).get());
}
break;
case FLOAT:
ret.addFloat(name, ((DoubleWritable)w).get());
break;
case STR:
ret.addStr(name, w.toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
}
}
return ret;
}
} }

View File

@ -16,16 +16,29 @@
package org.datavec.python; package org.datavec.python;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.io.IOUtils;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform; import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.nd4j.base.Preconditions;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.UUID; import java.util.UUID;
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
/** /**
* Row-wise Transform that applies arbitrary python code on each row * Row-wise Transform that applies arbitrary python code on each row
* *
@ -34,31 +47,87 @@ import java.util.UUID;
@NoArgsConstructor @NoArgsConstructor
@Data @Data
public class PythonTransform implements Transform{ public class PythonTransform implements Transform {
private String code; private String code;
private PythonVariables pyInputs; private PythonVariables inputs;
private PythonVariables pyOutputs; private PythonVariables outputs;
private String name; private String name = UUID.randomUUID().toString();
private Schema inputSchema; private Schema inputSchema;
private Schema outputSchema; private Schema outputSchema;
private String outputDict;
private boolean returnAllVariables;
private boolean setupAndRun = false;
public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{ @Builder
public PythonTransform(String code,
PythonVariables inputs,
PythonVariables outputs,
String name,
Schema inputSchema,
Schema outputSchema,
String outputDict,
boolean returnAllInputs,
boolean setupAndRun) {
Preconditions.checkNotNull(code,"No code found to run!");
this.code = code; this.code = code;
this.pyInputs = pyInputs; this.returnAllVariables = returnAllInputs;
this.pyOutputs = pyOutputs; this.setupAndRun = setupAndRun;
this.name = UUID.randomUUID().toString(); if(inputs != null)
this.inputs = inputs;
if(outputs != null)
this.outputs = outputs;
if(name != null)
this.name = name;
if (outputDict != null) {
this.outputDict = outputDict;
this.outputs = new PythonVariables();
this.outputs.addDict(outputDict);
String helpers;
try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
helpers = IOUtils.toString(is, Charset.defaultCharset());
}catch (IOException e){
throw new RuntimeException("Error reading python code");
}
this.code += "\n\n" + helpers;
this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
} }
try {
if(inputSchema != null) {
this.inputSchema = inputSchema;
if(inputs == null || inputs.isEmpty()) {
this.inputs = schemaToPythonVariables(inputSchema);
}
}
if(outputSchema != null) {
this.outputSchema = outputSchema;
if(outputs == null || outputs.isEmpty()) {
this.outputs = schemaToPythonVariables(outputSchema);
}
}
}catch(Exception e) {
throw new IllegalStateException(e);
}
}
@Override @Override
public void setInputSchema(Schema inputSchema){ public void setInputSchema(Schema inputSchema) {
Preconditions.checkNotNull(inputSchema,"No input schema found!");
this.inputSchema = inputSchema; this.inputSchema = inputSchema;
try{ try{
pyInputs = schemaToPythonVariables(inputSchema); inputs = schemaToPythonVariables(inputSchema);
}catch (Exception e){ }catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException(e);
} }
if (outputSchema == null){ if (outputSchema == null && outputDict == null){
outputSchema = inputSchema; outputSchema = inputSchema;
} }
@ -88,12 +157,42 @@ public class PythonTransform implements Transform{
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }
@Override @Override
public List<Writable> map(List<Writable> writables){ public List<Writable> map(List<Writable> writables) {
PythonVariables pyInputs = getPyInputsFromWritables(writables); PythonVariables pyInputs = getPyInputsFromWritables(writables);
Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");
try{ try{
PythonExecutioner.exec(code, pyInputs, pyOutputs); if (returnAllVariables) {
return getWritablesFromPyOutputs(pyOutputs); if (setupAndRun){
return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
}
return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
}
if (outputDict != null) {
if (setupAndRun) {
PythonExecutioner.execWithSetupAndRun(this, pyInputs);
}else{
PythonExecutioner.exec(this, pyInputs);
}
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
return getWritablesFromPyOutputs(out);
}
else {
if (setupAndRun) {
PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
}else{
PythonExecutioner.exec(code, pyInputs, outputs);
}
return getWritablesFromPyOutputs(outputs);
}
} }
catch (Exception e){ catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException(e);
@ -102,7 +201,7 @@ public class PythonTransform implements Transform{
@Override @Override
public String[] outputColumnNames(){ public String[] outputColumnNames(){
return pyOutputs.getVariables(); return outputs.getVariables();
} }
@Override @Override
@ -111,7 +210,7 @@ public class PythonTransform implements Transform{
} }
@Override @Override
public String[] columnNames(){ public String[] columnNames(){
return pyOutputs.getVariables(); return outputs.getVariables();
} }
@Override @Override
@ -124,14 +223,13 @@ public class PythonTransform implements Transform{
} }
private PythonVariables getPyInputsFromWritables(List<Writable> writables){ private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for (String name: pyInputs.getVariables()){ for (String name: inputs.getVariables()) {
int colIdx = inputSchema.getIndexOfColumn(name); int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx); Writable w = writables.get(colIdx);
PythonVariables.Type pyType = pyInputs.getType(name); PythonVariables.Type pyType = inputs.getType(name);
switch (pyType){ switch (pyType){
case INT: case INT:
if (w instanceof LongWritable){ if (w instanceof LongWritable){
@ -143,7 +241,7 @@ public class PythonTransform implements Transform{
break; break;
case FLOAT: case FLOAT:
if (w instanceof DoubleWritable){ if (w instanceof DoubleWritable) {
ret.addFloat(name, ((DoubleWritable)w).get()); ret.addFloat(name, ((DoubleWritable)w).get());
} }
else{ else{
@ -151,96 +249,99 @@ public class PythonTransform implements Transform{
} }
break; break;
case STR: case STR:
ret.addStr(name, ((Text)w).toString()); ret.addStr(name, w.toString());
break; break;
case NDARRAY: case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get()); ret.addNDArray(name,((NDArrayWritable)w).get());
break; break;
default:
throw new RuntimeException("Unsupported input type:" + pyType);
} }
} }
return ret; return ret;
} }
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts){ private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
List<Writable> out = new ArrayList<>(); List<Writable> out = new ArrayList<>();
for (int i=0; i<outputSchema.numColumns(); i++){ String[] varNames;
String name = outputSchema.getName(i); varNames = pyOuts.getVariables();
PythonVariables.Type pyType = pyOutputs.getType(name); Schema.Builder schemaBuilder = new Schema.Builder();
for (int i = 0; i < varNames.length; i++) {
String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name);
switch (pyType){ switch (pyType){
case INT: case INT:
out.add((Writable) new LongWritable(pyOuts.getIntValue(name))); schemaBuilder.addColumnLong(name);
break; break;
case FLOAT: case FLOAT:
out.add((Writable) new DoubleWritable(pyOuts.getFloatValue(name))); schemaBuilder.addColumnDouble(name);
break; break;
case STR: case STR:
out.add((Writable) new Text(pyOuts.getStrValue(name))); case DICT:
case LIST:
schemaBuilder.addColumnString(name);
break; break;
case NDARRAY: case NDARRAY:
out.add((Writable) new NDArrayWritable(pyOuts.getNDArrayValue(name).getNd4jArray())); NumpyArray arr = pyOuts.getNDArrayValue(name);
schemaBuilder.addColumnNDArray(name, arr.getShape());
break; break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
}
}
this.outputSchema = schemaBuilder.build();
for (int i = 0; i < varNames.length; i++) {
String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name);
switch (pyType){
case INT:
out.add(new LongWritable(pyOuts.getIntValue(name)));
break;
case FLOAT:
out.add(new DoubleWritable(pyOuts.getFloatValue(name)));
break;
case STR:
out.add(new Text(pyOuts.getStrValue(name)));
break;
case NDARRAY:
NumpyArray arr = pyOuts.getNDArrayValue(name);
out.add(new NDArrayWritable(arr.getNd4jArray()));
break;
case DICT:
Map<?, ?> dictValue = pyOuts.getDictValue(name);
Map noNullValues = new java.util.HashMap<>();
for(Map.Entry entry : dictValue.entrySet()) {
if(entry.getValue() != org.json.JSONObject.NULL) {
noNullValues.put(entry.getKey(), entry.getValue());
}
}
try {
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
}
break;
case LIST:
Object[] listValue = pyOuts.getListValue(name);
try {
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
}
break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
} }
} }
return out; return out;
} }
public PythonTransform(String code) throws Exception{
this.code = code;
this.name = UUID.randomUUID().toString();
}
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i=0; i<numCols; i++){
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
public PythonTransform(String code, Schema outputSchema) throws Exception{
this.code = code;
this.name = UUID.randomUUID().toString();
this.outputSchema = outputSchema;
this.pyOutputs = schemaToPythonVariables(outputSchema);
}
public String getName() {
return name;
}
public String getCode(){
return code;
}
public PythonVariables getInputs() {
return pyInputs;
}
public PythonVariables getOutputs() {
return pyOutputs;
}
} }

View File

@ -0,0 +1,306 @@
package org.datavec.python;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.BooleanMetaData;
import org.datavec.api.transform.schema.Schema;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* List of utilities for executing python transforms.
*
* @author Adam Gibson
*/
public class PythonUtils {
/**
* Create a {@link Schema}
* from {@link PythonVariables}.
* Types are mapped to types of the same name.
* @param input the input {@link PythonVariables}
* @return the output {@link Schema}
*/
public static Schema fromPythonVariables(PythonVariables input) {
Schema.Builder schemaBuilder = new Schema.Builder();
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none.");
for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) {
switch(entry.getValue()) {
case INT:
schemaBuilder.addColumnInteger(entry.getKey());
break;
case STR:
schemaBuilder.addColumnString(entry.getKey());
break;
case FLOAT:
schemaBuilder.addColumnFloat(entry.getKey());
break;
case NDARRAY:
schemaBuilder.addColumnNDArray(entry.getKey(),null);
break;
case BOOL:
schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
}
}
return schemaBuilder.build();
}
/**
* Create a {@link Schema} from an input
* {@link PythonVariables}
* Types are mapped to types of the same name
* @param input the input schema
* @return the output python variables.
*/
public static PythonVariables fromSchema(Schema input) {
PythonVariables ret = new PythonVariables();
for(int i = 0; i < input.numColumns(); i++) {
String currColumnName = input.getName(i);
ColumnType columnType = input.getType(i);
switch(columnType) {
case NDArray:
ret.add(currColumnName, PythonVariables.Type.NDARRAY);
break;
case Boolean:
ret.add(currColumnName, PythonVariables.Type.BOOL);
break;
case Categorical:
case String:
ret.add(currColumnName, PythonVariables.Type.STR);
break;
case Double:
case Float:
ret.add(currColumnName, PythonVariables.Type.FLOAT);
break;
case Integer:
case Long:
ret.add(currColumnName, PythonVariables.Type.INT);
break;
case Bytes:
break;
case Time:
throw new UnsupportedOperationException("Unable to process dates with python yet.");
}
}
return ret;
}
/**
* Convert a {@link Schema}
* to {@link PythonVariables}
* @param schema the input schema
* @return the output {@link PythonVariables} where each
* name in the map is associated with a column name in the schema.
* A proper type is also chosen based on the schema
* @throws Exception
*/
public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception {
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i = 0; i < numCols; i++) {
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
public static NumpyArray mapToNumpyArray(Map map){
String dtypeName = (String)map.get("dtype");
DataType dtype;
if (dtypeName.equals("float64")){
dtype = DataType.DOUBLE;
}
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT;
}
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT;
}
else if (dtypeName.equals("int32")){
dtype = DataType.INT;
}
else if (dtypeName.equals("int64")){
dtype = DataType.LONG;
}
else{
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
List shapeList = (List)map.get("shape");
long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i);
}
List strideList = (List)map.get("shape");
long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i);
}
long address = (Long)map.get("address");
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
return numpyArray;
}
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){
Map dict = pyvars.getDictValue(key);
String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]);
PythonVariables pyvars2 = new PythonVariables();
for (String subkey: keys){
Object value = dict.get(subkey);
if (value instanceof Map){
Map map = (Map)value;
if (map.containsKey("_is_numpy_array")){
pyvars2.addNDArray(subkey, mapToNumpyArray(map));
}
else{
pyvars2.addDict(subkey, (Map)value);
}
}
else if (value instanceof List){
pyvars2.addList(subkey, ((List) value).toArray());
}
else if (value instanceof String){
System.out.println((String)value);
pyvars2.addStr(subkey, (String) value);
}
else if (value instanceof Integer || value instanceof Long) {
Number number = (Number) value;
pyvars2.addInt(subkey, number.intValue());
}
else if (value instanceof Float || value instanceof Double) {
Number number = (Number) value;
pyvars2.addFloat(subkey, number.doubleValue());
}
else if (value instanceof NumpyArray){
pyvars2.addNDArray(subkey, (NumpyArray)value);
}
else if (value == null){
pyvars2.addStr(subkey, "None"); // FixMe
}
else{
throw new RuntimeException("Unsupported type!" + value);
}
}
return pyvars2;
}
public static long[] jsonArrayToLongArray(JSONArray jsonArray){
long[] longs = new long[jsonArray.length()];
for (int i=0; i<longs.length; i++){
longs[i] = jsonArray.getLong(i);
}
return longs;
}
public static Map<String, Object> toMap(JSONObject jsonobj) {
Map<String, Object> map = new HashMap<>();
String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
for (String key: keys){
Object value = jsonobj.get(key);
if (value instanceof JSONArray) {
value = toList((JSONArray) value);
} else if (value instanceof JSONObject) {
JSONObject jsonobj2 = (JSONObject)value;
if (jsonobj2.has("_is_numpy_array")){
value = jsonToNumpyArray(jsonobj2);
}
else{
value = toMap(jsonobj2);
}
}
map.put(key, value);
} return map;
}
public static List<Object> toList(JSONArray array) {
List<Object> list = new ArrayList<>();
for (int i = 0; i < array.length(); i++) {
Object value = array.get(i);
if (value instanceof JSONArray) {
value = toList((JSONArray) value);
} else if (value instanceof JSONObject) {
JSONObject jsonobj2 = (JSONObject) value;
if (jsonobj2.has("_is_numpy_array")) {
value = jsonToNumpyArray(jsonobj2);
} else {
value = toMap(jsonobj2);
}
}
list.add(value);
}
return list;
}
private static NumpyArray jsonToNumpyArray(JSONObject map){
String dtypeName = (String)map.get("dtype");
DataType dtype;
if (dtypeName.equals("float64")){
dtype = DataType.DOUBLE;
}
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT;
}
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT;
}
else if (dtypeName.equals("int32")){
dtype = DataType.INT;
}
else if (dtypeName.equals("int64")){
dtype = DataType.LONG;
}
else{
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
List shapeList = (List)map.get("shape");
long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i);
}
List strideList = (List)map.get("shape");
long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i);
}
long address = (Long)map.get("address");
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
return numpyArray;
}
}

View File

@ -17,8 +17,8 @@
package org.datavec.python; package org.datavec.python;
import lombok.Data; import lombok.Data;
import org.json.simple.JSONArray; import org.json.JSONObject;
import org.json.simple.JSONObject; import org.json.JSONArray;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable; import java.io.Serializable;
@ -31,8 +31,8 @@ import java.util.*;
* @author Fariz Rahman * @author Fariz Rahman
*/ */
@Data @lombok.Data
public class PythonVariables implements Serializable{ public class PythonVariables implements java.io.Serializable {
public enum Type{ public enum Type{
BOOL, BOOL,
@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{
FLOAT, FLOAT,
NDARRAY, NDARRAY,
LIST, LIST,
FILE FILE,
DICT
} }
private Map<String, String> strVars = new HashMap<String, String>(); private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
private Map<String, Long> intVars = new HashMap<String, Long>(); private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
private Map<String, Double> floatVars = new HashMap<String, Double>(); private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
private Map<String, Boolean> boolVars = new HashMap<String, Boolean>(); private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
private Map<String, NumpyArray> ndVars = new HashMap<String, NumpyArray>(); private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
private Map<String, Object[]> listVars = new HashMap<String, Object[]>(); private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
private Map<String, String> fileVars = new HashMap<String, String>(); private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>();
private Map<String, Type> vars = new HashMap<String, Type>(); private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();
private Map<Type, Map> maps = new HashMap<Type, Map>();
/**
* Returns a copy of the variable
* schema in this array without the values
* @return an empty variables clone
* with no values
*/
public PythonVariables copySchema(){ public PythonVariables copySchema(){
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for (String varName: getVariables()){ for (String varName: getVariables()){
@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{
} }
return ret; return ret;
} }
public PythonVariables(){
maps.put(Type.BOOL, boolVars);
maps.put(Type.STR, strVars);
maps.put(Type.INT, intVars);
maps.put(Type.FLOAT, floatVars);
maps.put(Type.NDARRAY, ndVars);
maps.put(Type.LIST, listVars);
maps.put(Type.FILE, fileVars);
/**
*
*/
public PythonVariables() {
maps.put(PythonVariables.Type.BOOL, boolVariables);
maps.put(PythonVariables.Type.STR, strVariables);
maps.put(PythonVariables.Type.INT, intVariables);
maps.put(PythonVariables.Type.FLOAT, floatVariables);
maps.put(PythonVariables.Type.NDARRAY, ndVars);
maps.put(PythonVariables.Type.LIST, listVariables);
maps.put(PythonVariables.Type.FILE, fileVariables);
maps.put(PythonVariables.Type.DICT, dictVariables);
}
/**
*
* @return true if there are no variables.
*/
public boolean isEmpty() {
return getVariables().length < 1;
} }
@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{
break; break;
case FILE: case FILE:
addFile(name); addFile(name);
break;
case DICT:
addDict(name);
} }
} }
@ -113,248 +137,459 @@ public class PythonVariables implements Serializable{
* @param name name of the variable * @param name name of the variable
* @param type type of the variable * @param type type of the variable
* @param value value of the variable (must be instance of expected type) * @param value value of the variable (must be instance of expected type)
* @throws Exception
*/ */
public void add (String name, Type type, Object value) throws Exception{ public void add(String name, Type type, Object value) {
add(name, type); add(name, type);
setValue(name, value); setValue(name, value);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addDict(String name) {
vars.put(name, PythonVariables.Type.DICT);
dictVariables.put(name,null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addBool(String name){ public void addBool(String name){
vars.put(name, Type.BOOL); vars.put(name, PythonVariables.Type.BOOL);
boolVars.put(name, null); boolVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addStr(String name){ public void addStr(String name){
vars.put(name, Type.STR); vars.put(name, PythonVariables.Type.STR);
strVars.put(name, null); strVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addInt(String name){ public void addInt(String name){
vars.put(name, Type.INT); vars.put(name, PythonVariables.Type.INT);
intVars.put(name, null); intVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addFloat(String name){ public void addFloat(String name){
vars.put(name, Type.FLOAT); vars.put(name, PythonVariables.Type.FLOAT);
floatVars.put(name, null); floatVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addNDArray(String name){ public void addNDArray(String name){
vars.put(name, Type.NDARRAY); vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, null); ndVars.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addList(String name){ public void addList(String name){
vars.put(name, Type.LIST); vars.put(name, PythonVariables.Type.LIST);
listVars.put(name, null); listVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addFile(String name){ public void addFile(String name){
vars.put(name, Type.FILE); vars.put(name, PythonVariables.Type.FILE);
fileVars.put(name, null); fileVariables.put(name, null);
}
public void addBool(String name, boolean value){
vars.put(name, Type.BOOL);
boolVars.put(name, value);
} }
public void addStr(String name, String value){ /**
vars.put(name, Type.STR); * Add a boolean variable to
strVars.put(name, value); * the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addBool(String name, boolean value) {
vars.put(name, PythonVariables.Type.BOOL);
boolVariables.put(name, value);
} }
public void addInt(String name, int value){ /**
vars.put(name, Type.INT); * Add a string variable to
intVars.put(name, (long)value); * the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addStr(String name, String value) {
vars.put(name, PythonVariables.Type.STR);
strVariables.put(name, value);
} }
public void addInt(String name, long value){ /**
vars.put(name, Type.INT); * Add an int variable to
intVars.put(name, value); * the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addInt(String name, int value) {
vars.put(name, PythonVariables.Type.INT);
intVariables.put(name, (long)value);
} }
public void addFloat(String name, double value){ /**
vars.put(name, Type.FLOAT); * Add a long variable to
floatVars.put(name, value); * the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addInt(String name, long value) {
vars.put(name, PythonVariables.Type.INT);
intVariables.put(name, value);
} }
public void addFloat(String name, float value){ /**
vars.put(name, Type.FLOAT); * Add a double variable to
floatVars.put(name, (double)value); * the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addFloat(String name, double value) {
vars.put(name, PythonVariables.Type.FLOAT);
floatVariables.put(name, value);
} }
public void addNDArray(String name, NumpyArray value){ /**
vars.put(name, Type.NDARRAY); * Add a float variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addFloat(String name, float value) {
vars.put(name, PythonVariables.Type.FLOAT);
floatVariables.put(name, (double)value);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, NumpyArray value) {
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, value); ndVars.put(name, value);
} }
public void addNDArray(String name, INDArray value){ /**
vars.put(name, Type.NDARRAY); * Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, new NumpyArray(value)); ndVars.put(name, new NumpyArray(value));
} }
public void addList(String name, Object[] value){ /**
vars.put(name, Type.LIST); * Add a null variable to
listVars.put(name, value); * the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addList(String name, Object[] value) {
vars.put(name, PythonVariables.Type.LIST);
listVariables.put(name, value);
} }
public void addFile(String name, String value){ /**
vars.put(name, Type.FILE); * Add a null variable to
fileVars.put(name, value); * the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addFile(String name, String value) {
vars.put(name, PythonVariables.Type.FILE);
fileVariables.put(name, value);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addDict(String name, java.util.Map value) {
vars.put(name, PythonVariables.Type.DICT);
dictVariables.put(name, value);
}
/** /**
* *
* @param name name of the variable * @param name name of the variable
* @param value new value for the variable * @param value new value for the variable
* @throws Exception
*/ */
public void setValue(String name, Object value) { public void setValue(String name, Object value) {
Type type = vars.get(name); Type type = vars.get(name);
if (type == Type.BOOL){ if (type == PythonVariables.Type.BOOL){
boolVars.put(name, (Boolean)value); boolVariables.put(name, (Boolean)value);
} }
else if (type == Type.INT){ else if (type == PythonVariables.Type.INT){
if (value instanceof Long){ Number number = (Number) value;
intVars.put(name, ((Long)value)); intVariables.put(name, number.longValue());
} }
else if (value instanceof Integer){ else if (type == PythonVariables.Type.FLOAT){
intVars.put(name, ((Integer)value).longValue()); Number number = (Number) value;
floatVariables.put(name, number.doubleValue());
} }
} else if (type == PythonVariables.Type.NDARRAY){
else if (type == Type.FLOAT){
floatVars.put(name, (Double)value);
}
else if (type == Type.NDARRAY){
if (value instanceof NumpyArray){ if (value instanceof NumpyArray){
ndVars.put(name, (NumpyArray)value); ndVars.put(name, (NumpyArray)value);
} }
else if (value instanceof INDArray){ else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
ndVars.put(name, new NumpyArray((INDArray) value)); ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
} }
else{ else{
throw new RuntimeException("Unsupported type: " + value.getClass().toString()); throw new RuntimeException("Unsupported type: " + value.getClass().toString());
} }
} }
else if (type == Type.LIST){ else if (type == PythonVariables.Type.LIST) {
listVars.put(name, (Object[]) value); if (value instanceof java.util.List) {
value = ((java.util.List) value).toArray();
listVariables.put(name, (Object[]) value);
} }
else if (type == Type.FILE){ else if(value instanceof org.json.JSONArray) {
fileVars.put(name, (String)value); org.json.JSONArray jsonArray = (org.json.JSONArray) value;
Object[] copyArr = new Object[jsonArray.length()];
for(int i = 0; i < copyArr.length; i++) {
copyArr[i] = jsonArray.get(i);
}
listVariables.put(name, copyArr);
}
else {
listVariables.put(name, (Object[]) value);
}
}
else if(type == PythonVariables.Type.DICT) {
dictVariables.put(name,(java.util.Map<?,?>) value);
}
else if (type == PythonVariables.Type.FILE){
fileVariables.put(name, (String)value);
} }
else{ else{
strVars.put(name, (String)value); strVariables.put(name, (String)value);
} }
} }
public Object getValue(String name){ /**
* Do a general object lookup.
* The look up will happen relative to the {@link Type}
* of variable is described in the
* @param name the name of the variable to get
* @return teh value for the variable with the given name
*/
public Object getValue(String name) {
Type type = vars.get(name); Type type = vars.get(name);
Map map = maps.get(type); java.util.Map map = maps.get(type);
return map.get(name); return map.get(name);
} }
/**
* Returns a boolean variable with the given name.
* @param name the variable name to get the value for
* @return the retrieved boolean value
*/
public boolean getBooleanValue(String name) {
return boolVariables.get(name);
}
/**
*
* @param name the variable name
* @return the dictionary value
*/
public java.util.Map<?,?> getDictValue(String name) {
return dictVariables.get(name);
}
/**
/**
*
* @param name the variable name
* @return the string value
*/
public String getStrValue(String name){ public String getStrValue(String name){
return strVars.get(name); return strVariables.get(name);
} }
public long getIntValue(String name){ /**
return intVars.get(name); *
* @param name the variable name
* @return the long value
*/
public Long getIntValue(String name){
return intVariables.get(name);
} }
public double getFloatValue(String name){ /**
return floatVars.get(name); *
* @param name the variable name
* @return the float value
*/
public Double getFloatValue(String name){
return floatVariables.get(name);
} }
/**
*
* @param name the variable name
* @return the numpy array value
*/
public NumpyArray getNDArrayValue(String name){ public NumpyArray getNDArrayValue(String name){
return ndVars.get(name); return ndVars.get(name);
} }
/**
*
* @param name the variable name
* @return the list value as an object array
*/
public Object[] getListValue(String name){ public Object[] getListValue(String name){
return listVars.get(name); return listVariables.get(name);
} }
/**
*
* @param name the variable name
* @return the value of the given file name
*/
public String getFileValue(String name){ public String getFileValue(String name){
return fileVars.get(name); return fileVariables.get(name);
} }
/**
* Returns the type for the given variable name
* @param name the name of the variable to get the type for
* @return the type for the given variable
*/
public Type getType(String name){ public Type getType(String name){
return vars.get(name); return vars.get(name);
} }
/**
* Get all the variables present as a string array
* @return the variable names for this variable sset
*/
public String[] getVariables() { public String[] getVariables() {
String[] strArr = new String[vars.size()]; String[] strArr = new String[vars.size()];
return vars.keySet().toArray(strArr); return vars.keySet().toArray(strArr);
} }
public Map<String, Boolean> getBoolVariables(){ /**
return boolVars; * This variables set as its json representation (an array of json objects)
} * @return the json array output
public Map<String, String> getStrVariables(){ */
return strVars; public org.json.JSONArray toJSON(){
} org.json.JSONArray arr = new org.json.JSONArray();
public Map<String, Long> getIntVariables(){
return intVars;
}
public Map<String, Double> getFloatVariables(){
return floatVars;
}
public Map<String, NumpyArray> getNDArrayVariables(){
return ndVars;
}
public Map<String, Object[]> getListVariables(){
return listVars;
}
public Map<String, String> getFileVariables(){
return fileVars;
}
public JSONArray toJSON(){
JSONArray arr = new JSONArray();
for (String varName: getVariables()){ for (String varName: getVariables()){
JSONObject var = new JSONObject(); org.json.JSONObject var = new org.json.JSONObject();
var.put("name", varName); var.put("name", varName);
String varType = getType(varName).toString(); String varType = getType(varName).toString();
var.put("type", varType); var.put("type", varType);
arr.add(var); arr.put(var);
} }
return arr; return arr;
} }
public static PythonVariables fromJSON(JSONArray jsonArray){ /**
* Create a schema from a map.
* This is an empty PythonVariables
* that just contains names and types with no values
* @param inputTypes the input types to convert
* @return the schema from the given map
*/
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
PythonVariables ret = new PythonVariables();
for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) {
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
}
return ret;
}
/**
* Get the python variable state relative to the
* input json array
* @param jsonArray the input json array
* @return the python variables based on the input json array
*/
public static PythonVariables fromJSON(org.json.JSONArray jsonArray){
PythonVariables pyvars = new PythonVariables(); PythonVariables pyvars = new PythonVariables();
for (int i=0; i<jsonArray.size(); i++){ for (int i = 0; i < jsonArray.length(); i++) {
JSONObject input = (JSONObject) jsonArray.get(i); org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
String varName = (String)input.get("name"); String varName = (String)input.get("name");
String varType = (String)input.get("type"); String varType = (String)input.get("type");
if (varType.equals("BOOL")){ if (varType.equals("BOOL")) {
pyvars.addBool(varName); pyvars.addBool(varName);
} }
else if (varType.equals("INT")){ else if (varType.equals("INT")) {
pyvars.addInt(varName); pyvars.addInt(varName);
} }
else if (varType.equals("FlOAT")){ else if (varType.equals("FlOAT")){
pyvars.addFloat(varName); pyvars.addFloat(varName);
} }
else if (varType.equals("STR")){ else if (varType.equals("STR")) {
pyvars.addStr(varName); pyvars.addStr(varName);
} }
else if (varType.equals("LIST")){ else if (varType.equals("LIST")) {
pyvars.addList(varName); pyvars.addList(varName);
} }
else if (varType.equals("FILE")){ else if (varType.equals("FILE")){
pyvars.addFile(varName); pyvars.addFile(varName);
} }
else if (varType.equals("NDARRAY")){ else if (varType.equals("NDARRAY")) {
pyvars.addNDArray(varName); pyvars.addNDArray(varName);
} }
else if(varType.equals("DICT")) {
pyvars.addDict(varName);
}
} }
return pyvars; return pyvars;

View File

@ -0,0 +1,5 @@
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
import sys
this = sys.modules[__name__]
for n in dir():
if n[0]!='_': delattr(this, n)

View File

@ -0,0 +1 @@
loc = {}

View File

@ -0,0 +1,20 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def maybe_serialize_ndarray_metadata(x):
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def is_json_ready(key, value):
return key is not 'f2' and not inspect.ismodule(value) \
and not hasattr(value, '__call__')

View File

@ -0,0 +1,202 @@
#patch
"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
from numpy.core._multiarray_umath import (
add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
_add_docstring = add_docstring
def add_docstring(*args):
try:
_add_docstring(*args)
except:
pass
add_docstring(
implement_array_function,
"""
Implement a function with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Arguments
---------
implementation : function
Function that implements the operation on NumPy array without
overrides when called like ``implementation(*args, **kwargs)``.
public_api : function
Function exposed by NumPy's public API originally called like
``public_api(*args, **kwargs)`` on which arguments are now being
checked.
relevant_args : iterable
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
""")
# exposed for testing purposes; used internally by implement_array_function
add_docstring(
_get_implementing_args,
"""
Collect arguments on which to call __array_function__.
Parameters
----------
relevant_args : iterable of array-like
Iterable of possibly array-like arguments to check for
__array_function__ methods.
Returns
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
""")
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
def verify_matching_signatures(implementation, dispatcher):
"""Verify that a dispatcher function has the right signature."""
implementation_spec = ArgSpec(*getargspec(implementation))
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
if (implementation_spec.args != dispatcher_spec.args or
implementation_spec.varargs != dispatcher_spec.varargs or
implementation_spec.keywords != dispatcher_spec.keywords or
(bool(implementation_spec.defaults) !=
bool(dispatcher_spec.defaults)) or
(implementation_spec.defaults is not None and
len(implementation_spec.defaults) !=
len(dispatcher_spec.defaults))):
raise RuntimeError('implementation and dispatcher for %s have '
'different function signatures' % implementation)
if implementation_spec.defaults is not None:
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
raise RuntimeError('dispatcher functions can only use None for '
'default argument values')
def set_module(module):
"""Decorator for overriding __module__ on a function or class.
Example usage::
@set_module('numpy')
def example():
pass
assert example.__module__ == 'numpy'
"""
def decorator(func):
if module is not None:
func.__module__ = module
return func
return decorator
def array_function_dispatch(dispatcher, module=None, verify=True,
docs_from_dispatcher=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
See NEP-18 for example usage.
Parameters
----------
dispatcher : callable
Function that when called like ``dispatcher(*args, **kwargs)`` with
arguments from the NumPy function call returns an iterable of
array-like arguments to check for ``__array_function__``.
module : str, optional
__module__ attribute to set on new function, e.g., ``module='numpy'``.
By default, module is copied from the decorated function.
verify : bool, optional
If True, verify the that the signature of the dispatcher and decorated
function signatures match exactly: all required and optional arguments
should appear in order with the same names, but the default values for
all optional arguments should be ``None``. Only disable verification
if the dispatcher's signature needs to deviate for some particular
reason, e.g., because the function has a signature like
``func(*args, **kwargs)``.
docs_from_dispatcher : bool, optional
If True, copy docs from the dispatcher function onto the dispatched
function, rather than from the implementation. This is useful for
functions defined in C, which otherwise don't have docstrings.
Returns
-------
Function suitable for decorating the implementation of a NumPy function.
"""
if not ENABLE_ARRAY_FUNCTION:
# __array_function__ requires an explicit opt-in for now
def decorator(implementation):
if module is not None:
implementation.__module__ = module
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
return implementation
return decorator
def decorator(implementation):
if verify:
verify_matching_signatures(implementation, dispatcher)
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None:
public_api.__module__ = module
# TODO: remove this when we drop Python 2 support (functools.wraps)
# adds __wrapped__ automatically in later versions)
public_api.__wrapped__ = implementation
return public_api
return decorator
def array_function_from_dispatcher(
implementation, module=None, verify=True, docs_from_dispatcher=True):
"""Like array_function_dispatcher, but with function arguments flipped."""
def decorator(dispatcher):
return array_function_dispatch(
dispatcher, module, verify=verify,
docs_from_dispatcher=docs_from_dispatcher)(implementation)
return decorator

View File

@ -0,0 +1,172 @@
#patch 1
"""
========================
Random Number Generation
========================
==================== =========================================================
Utility functions
==============================================================================
random_sample Uniformly distributed floats over ``[0, 1)``.
random Alias for `random_sample`.
bytes Uniformly distributed random bytes.
random_integers Uniformly distributed integers in a given range.
permutation Randomly permute a sequence / generate a random sequence.
shuffle Randomly permute a sequence in place.
seed Seed the random number generator.
choice Random sample from 1-D array.
==================== =========================================================
==================== =========================================================
Compatibility functions
==============================================================================
rand Uniformly distributed values.
randn Normally distributed values.
ranf Uniformly distributed floating point numbers.
randint Uniformly distributed integers in a given range.
==================== =========================================================
==================== =========================================================
Univariate distributions
==============================================================================
beta Beta distribution over ``[0, 1]``.
binomial Binomial distribution.
chisquare :math:`\\chi^2` distribution.
exponential Exponential distribution.
f F (Fisher-Snedecor) distribution.
gamma Gamma distribution.
geometric Geometric distribution.
gumbel Gumbel distribution.
hypergeometric Hypergeometric distribution.
laplace Laplace distribution.
logistic Logistic distribution.
lognormal Log-normal distribution.
logseries Logarithmic series distribution.
negative_binomial Negative binomial distribution.
noncentral_chisquare Non-central chi-square distribution.
noncentral_f Non-central F distribution.
normal Normal / Gaussian distribution.
pareto Pareto distribution.
poisson Poisson distribution.
power Power distribution.
rayleigh Rayleigh distribution.
triangular Triangular distribution.
uniform Uniform distribution.
vonmises Von Mises circular distribution.
wald Wald (inverse Gaussian) distribution.
weibull Weibull distribution.
zipf Zipf's distribution over ranked data.
==================== =========================================================
==================== =========================================================
Multivariate distributions
==============================================================================
dirichlet Multivariate generalization of Beta distribution.
multinomial Multivariate generalization of the binomial distribution.
multivariate_normal Multivariate generalization of the normal distribution.
==================== =========================================================
==================== =========================================================
Standard distributions
==============================================================================
standard_cauchy Standard Cauchy-Lorentz distribution.
standard_exponential Standard exponential distribution.
standard_gamma Standard Gamma distribution.
standard_normal Standard normal distribution.
standard_t Standard Student's t-distribution.
==================== =========================================================
==================== =========================================================
Internal functions
==============================================================================
get_state Get tuple representing internal state of generator.
set_state Set state of generator.
==================== =========================================================
"""
from __future__ import division, absolute_import, print_function
import warnings
__all__ = [
'beta',
'binomial',
'bytes',
'chisquare',
'choice',
'dirichlet',
'exponential',
'f',
'gamma',
'geometric',
'get_state',
'gumbel',
'hypergeometric',
'laplace',
'logistic',
'lognormal',
'logseries',
'multinomial',
'multivariate_normal',
'negative_binomial',
'noncentral_chisquare',
'noncentral_f',
'normal',
'pareto',
'permutation',
'poisson',
'power',
'rand',
'randint',
'randn',
'random_integers',
'random_sample',
'rayleigh',
'seed',
'set_state',
'shuffle',
'standard_cauchy',
'standard_exponential',
'standard_gamma',
'standard_normal',
'standard_t',
'triangular',
'uniform',
'vonmises',
'wald',
'weibull',
'zipf'
]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
try:
from .mtrand import *
# Some aliases:
ranf = random = sample = random_sample
__all__.extend(['ranf', 'random', 'sample'])
except:
warnings.warn("numpy.random is not available when using multiple interpreters!")
def __RandomState_ctor():
"""Return a RandomState instance.
This function exists solely to assist (un)pickling.
Note that the state of the RandomState returned here is irrelevant, as this function's
entire purpose is to return a newly allocated RandomState whose state pickle can set.
Consequently the RandomState returned by this function is a freshly allocated copy
with a seed=0.
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
"""
return RandomState(seed=0)
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester

View File

@ -0,0 +1,20 @@
import sys
import traceback
import json
import inspect
try:
pass
sys.stdout.flush()
sys.stderr.flush()
except Exception as ex:
try:
exc_info = sys.exc_info()
finally:
print(ex)
traceback.print_exception(*exc_info)
sys.stdout.flush()
sys.stderr.flush()

View File

@ -0,0 +1,50 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def __maybe_serialize_ndarray_metadata(x):
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def __serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def __serialize_list(x):
import json
return json.dumps(__recursive_serialize_list(x))
def __serialize_dict(x):
import json
return json.dumps(__recursive_serialize_dict(x))
def __recursive_serialize_list(x):
out = []
for i in x:
if __is_numpy_array(i):
out.append(__serialize_ndarray_metadata(i))
elif isinstance(i, (list, tuple)):
out.append(__recursive_serialize_list(i))
elif isinstance(i, dict):
out.append(__recursive_serialize_dict(i))
else:
out.append(i)
return out
def __recursive_serialize_dict(x):
out = {}
for k in x:
v = x[k]
if __is_numpy_array(v):
out[k] = __serialize_ndarray_metadata(v)
elif isinstance(v, (list, tuple)):
out[k] = __recursive_serialize_list(v)
elif isinstance(v, dict):
out[k] = __recursive_serialize_dict(v)
else:
out[k] = v
return out

View File

@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Assert;
import org.junit.Test;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutionSandbox {
@Test
public void testInt(){
PythonExecutioner.setInterpreter("interp1");
PythonExecutioner.exec("a = 1");
PythonExecutioner.setInterpreter("interp2");
PythonExecutioner.exec("a = 2");
PythonExecutioner.setInterpreter("interp3");
PythonExecutioner.exec("a = 3");
PythonExecutioner.setInterpreter("interp1");
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp2");
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp3");
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
}
@Test
public void testNDArray(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 2");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 3");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
// PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
}
@Test
public void testNumpyRandom(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
}
}

View File

@ -15,17 +15,25 @@
******************************************************************************/ ******************************************************************************/
package org.datavec.python; package org.datavec.python;
import org.junit.Ignore; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutioner { public class TestPythonExecutioner {
@Test(timeout = 60000L)
@org.junit.Test
public void testPythonSysVersion() {
PythonExecutioner.exec("import sys; print(sys.version)");
}
@Test
public void testStr() throws Exception{ public void testStr() throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
@ -47,7 +55,7 @@ public class TestPythonExecutioner {
assertEquals("Hello World", z); assertEquals("Hello World", z);
} }
@Test(timeout = 60000L) @Test
public void testInt()throws Exception{ public void testInt()throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -64,11 +72,11 @@ public class TestPythonExecutioner {
long z = pyOutputs.getIntValue("z"); long z = pyOutputs.getIntValue("z");
assertEquals(30, z); Assert.assertEquals(30, z);
} }
@Test(timeout = 60000L) @Test
public void testList() throws Exception{ public void testList() throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -88,18 +96,35 @@ public class TestPythonExecutioner {
Object[] z = pyOutputs.getListValue("z"); Object[] z = pyOutputs.getListValue("z");
assertEquals(z.length, x.length + y.length); Assert.assertEquals(z.length, x.length + y.length);
for (int i=0; i < x.length; i++){ for (int i = 0; i < x.length; i++) {
assertEquals(x[i], z[i]); if(x[i] instanceof Number) {
Number xNum = (Number) x[i];
Number zNum = (Number) z[i];
Assert.assertEquals(xNum.intValue(), zNum.intValue());
} }
for (int i=0; i<y.length; i++){ else {
assertEquals(y[i], z[x.length + i]); Assert.assertEquals(x[i], z[i]);
}
}
for (int i = 0; i < y.length; i++){
if(y[i] instanceof Number) {
Number yNum = (Number) y[i];
Number zNum = (Number) z[x.length + i];
Assert.assertEquals(yNum.intValue(), zNum.intValue());
}
else {
Assert.assertEquals(y[i], z[x.length + i]);
}
} }
} }
@Test(timeout = 60000L) @Test
public void testNDArrayFloat()throws Exception{ public void testNDArrayFloat()throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -113,12 +138,17 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs); PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test(timeout = 60000L) @Test
public void testTensorflowCustomAnaconda() {
PythonExecutioner.exec("import tensorflow as tf");
}
@Test
public void testNDArrayDouble()throws Exception { public void testNDArrayDouble()throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -132,10 +162,10 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs); PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test(timeout = 60000L) @Test
public void testNDArrayShort()throws Exception{ public void testNDArrayShort()throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -149,11 +179,11 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs); PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test(timeout = 60000L) @Test
public void testNDArrayInt()throws Exception{ public void testNDArrayInt()throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -167,11 +197,11 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs); PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test(timeout = 60000L) @Test
public void testNDArrayLong()throws Exception{ public void testNDArrayLong()throws Exception{
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -185,7 +215,7 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs); PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }

View File

@ -0,0 +1,27 @@
package org.datavec.python;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonSetupAndRun {
@Test
public void testPythonWithSetupAndRun() throws Exception{
String code = "def setup():" +
"global counter;counter=0\n" +
"def run(step):" +
"global counter;" +
"counter+=step;" +
"return {\"counter\":counter}";
PythonVariables pyInputs = new PythonVariables();
pyInputs.addInt("step", 2);
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addInt("counter");
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
pyInputs.addInt("step", 3);
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
}
}

View File

@ -0,0 +1,102 @@
/*
*
* * ******************************************************************************
* * * Copyright (c) 2015-2019 Skymind Inc.
* * * Copyright (c) 2019 Konduit AI.
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.datavec.python;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.Collections;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestPythonVariables {
@Test
public void testImportNumpy(){
Nd4j.scalar(1.0);
System.out.println(System.getProperty("org.bytedeco.openblas.load"));
PythonExecutioner.exec("import numpy as np");
}
@Test
public void testDataAssociations() {
PythonVariables pythonVariables = new PythonVariables();
PythonVariables.Type[] types = {
PythonVariables.Type.INT,
PythonVariables.Type.FLOAT,
PythonVariables.Type.STR,
PythonVariables.Type.BOOL,
PythonVariables.Type.DICT,
PythonVariables.Type.LIST,
PythonVariables.Type.LIST,
PythonVariables.Type.FILE,
PythonVariables.Type.NDARRAY
};
NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
Object[] values = {
1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, Arrays.asList(1),"type", npArr
};
Object[] expectedValues = {
1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, new Object[]{1},"type", npArr
};
for(int i = 0; i < types.length; i++) {
testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]);
}
assertEquals(types.length,pythonVariables.getVariables().length);
}
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) {
pythonVariables.add(key, type);
assertNull(pythonVariables.getValue(key));
pythonVariables.setValue(key,value);
assertNotNull(pythonVariables.getValue(key));
Object actualValue = pythonVariables.getValue(key);
if (expectedValue instanceof Object[]){
assertTrue(actualValue instanceof Object[]);
Object[] actualArr = (Object[])actualValue;
Object[] expectedArr = (Object[])expectedValue;
assertArrayEquals(expectedArr, actualArr);
}
else{
assertEquals(expectedValue,pythonVariables.getValue(key));
}
}
}

View File

@ -29,7 +29,7 @@ public class TestSerde {
public static JsonSerializer j = new JsonSerializer(); public static JsonSerializer j = new JsonSerializer();
@Test(timeout = 60000L) @Test(timeout = 60000L)
public void testBasicSerde() throws Exception{ public void testBasicSerde(){
Schema schema = new Schema.Builder() Schema schema = new Schema.Builder()
.addColumnInteger("col1") .addColumnInteger("col1")
.addColumnFloat("col2") .addColumnFloat("col2")
@ -37,10 +37,9 @@ public class TestSerde {
.addColumnDouble("col4") .addColumnDouble("col4")
.build(); .build();
Transform t = new PythonTransform( Transform t = PythonTransform.builder().code(
"col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0", "col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0"
schema ).inputSchema(schema).outputSchema(schema).build();
);
String yaml = y.serialize(t); String yaml = y.serialize(t);
String json = j.serialize(t); String json = j.serialize(t);

View File

@ -58,7 +58,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -52,7 +52,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -171,7 +171,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -38,7 +38,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -138,7 +138,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -247,10 +247,9 @@ public class ExecutionTest extends BaseSparkTest {
.addColumnInteger("col1").addColumnDouble("col2").build(); .addColumnInteger("col1").addColumnDouble("col2").build();
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform( TransformProcess tp = new TransformProcess.Builder(schema).transform(
new PythonTransform( PythonTransform.builder().code(
pythonCode, "first = np.sin(first)\nsecond = np.cos(second)")
finalSchema .outputSchema(finalSchema).build()
)
).build(); ).build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -288,10 +287,9 @@ public class ExecutionTest extends BaseSparkTest {
String pythonCode = "col3 = col1 + col2"; String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(schema).transform( TransformProcess tp = new TransformProcess.Builder(schema).transform(
new PythonTransform( PythonTransform.builder().code(
pythonCode, "first = np.sin(first)\nsecond = np.cos(second)")
finalSchema .outputSchema(schema).build()
)
).build(); ).build();
INDArray zeros = Nd4j.zeros(shape); INDArray zeros = Nd4j.zeros(shape);

View File

@ -112,7 +112,7 @@
<skip>${skipTestResourceEnforcement}</skip> <skip>${skipTestResourceEnforcement}</skip>
<rules> <rules>
<requireActiveProfile> <requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles> <profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
<all>false</all> <all>false</all>
</requireActiveProfile> </requireActiveProfile>
</rules> </rules>
@ -365,11 +365,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -40,7 +40,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -163,11 +163,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
} }
} }
} }
@Test
public void testCnn1Causal() {
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int[] lengths = {11, 12, 13, 9, 10, 11};
int[] kernels = {2, 3, 2, 4, 2, 3};
int[] dilations = {1, 1, 2, 1, 2, 1};
int[] strides = {1, 2, 1, 2, 1, 1};
boolean[] masks = {false, true, false, true, false, true};
boolean[] hasB = {true, false, true, false, true, true};
for (int i = 0; i < lengths.length; i++) {
int length = lengths[i];
int k = kernels[i];
int d = dilations[i];
int st = strides[i];
boolean mask = masks[i];
boolean hasBias = hasB[i];
//TODO has bias
String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.weightInit(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.hasBias(hasBias)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nIn(convNIn).nOut(convNOut1)
.build())
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nIn(convNOut1).nOut(convNOut2)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
INDArray fm = null;
if (mask) {
fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1);
}
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
assertTrue(s, gradOK);
TestUtils.testModelSerialization(net);
}
}
} }

View File

@ -21,12 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Ignore; import org.junit.Ignore;
@ -289,4 +291,66 @@ public class RnnGradientChecks extends BaseDL4JTest {
} }
} }
} }
@Test
public void testTimeDistributedDense() {
int nIn = 3;
int nOut = 5;
int tsLength = 4;
int layerSize = 8;
Random r = new Random(12345);
for (int mb : new int[]{1, 3}) {
for (boolean inputMask : new boolean[]{false, true}) {
INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, nOut, tsLength);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
for (int i = 0; i < mb; i++) {
int firstMaskedStep = tsLength - 1 - i;
if (firstMaskedStep == 0) {
firstMaskedStep = tsLength;
}
for (int j = firstMaskedStep; j < tsLength; j++) {
inMask.putScalar(i, j, 0.0);
}
}
}
String name = "testLastTimeStepLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
if (PRINT_RESULTS) {
System.out.println("Starting test: " + name);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16);
assertTrue(name, gradOK);
TestUtils.testModelSerialization(net);
}
}
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer; import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {
assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn()); assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
} }
@Test
public void testPreprocessorVertex(){
for(boolean withMinibatchDim : new boolean[]{true, false}){
long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
for( long minibatch : new long[]{1, 3}) {
long[] inArrayShape = new long[]{minibatch, 32};
long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
long length = minibatch * 32;
INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
for( int i=0; i<3; i++ ) {
INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = in.reshape(targetArrayShape);
assertEquals(expOut, out);
INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
assertEquals(in, backprop);
}
}
}
}
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.dtypes; package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath; import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -811,7 +812,8 @@ public class DTypeTests extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nOut(5).build()) .layer(new DenseLayer.Builder().nOut(5).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())) .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
.layer(new SimpleRnn.Builder().nIn(10).nOut(5).build()) .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
.layer(secondLast) .layer(secondLast)
.layer(ol) .layer(ol)

View File

@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
} }
} }
@Test
public void testConv1dCausalAllowed(){
new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
}
@Test
public void testConv2dNoCausalAllowed(){
try{
new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
}
@Test
public void testConv3dNoCausalAllowed(){
try{
new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
try{
new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build();
fail("Expected exception");
} catch (Throwable t){
String m = t.getMessage().toLowerCase();
assertTrue(m, m.contains("causal") && m.contains("1d"));
}
}
} }

View File

@ -0,0 +1,88 @@
package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertEquals;
public class TestTimeDistributed extends BaseDL4JTest {
@Test
public void testTimeDistributed(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net1.init();
net2.init();
for( int mb : new int[]{1, 5}) {
for(char inLabelOrder : new char[]{'c', 'f'}) {
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
DataSet ds = new DataSet(in, labels);
net1.fit(ds);
net2.fit(ds);
assertEquals(net1.params(), net2.params());
MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
out2 = net2.output(in);
INDArray out3 = net3.output(in);
assertEquals(out2, out3);
}
}
}
}
}

View File

@ -16,7 +16,7 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-cuda-10.1</artifactId> <artifactId>deeplearning4j-cuda-10.2</artifactId>
<name>deeplearning4j-cuda</name> <name>deeplearning4j-cuda</name>
<parent> <parent>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
@ -26,7 +26,7 @@
<properties> <properties>
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml --> <!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
<cuda.version>10.1</cuda.version> <cuda.version>10.2</cuda.version>
<cudnn.version>7.6</cudnn.version> <cudnn.version>7.6</cudnn.version>
<javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version> <javacpp-presets.cuda.version>1.5.2</javacpp-presets.cuda.version>
</properties> </properties>
@ -106,7 +106,7 @@
</build> </build>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -51,7 +51,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -46,7 +46,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -48,7 +48,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -38,7 +38,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -116,11 +116,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -58,7 +58,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -62,7 +62,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -41,7 +41,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -302,11 +302,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -115,11 +115,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -356,6 +356,10 @@ public class KerasLayer {
return this.layer; return this.layer;
} }
public void setLayer(Layer layer){
this.layer = layer;
}
/** /**
* Whether this Keras layer maps to a DL4J Vertex. * Whether this Keras layer maps to a DL4J Vertex.
* *

View File

@ -233,6 +233,7 @@ public class KerasLayerConfiguration {
private final String LAYER_BORDER_MODE_SAME = "same"; private final String LAYER_BORDER_MODE_SAME = "same";
private final String LAYER_BORDER_MODE_VALID = "valid"; private final String LAYER_BORDER_MODE_VALID = "valid";
private final String LAYER_BORDER_MODE_FULL = "full"; private final String LAYER_BORDER_MODE_FULL = "full";
private final String LAYER_BORDER_MODE_CAUSAL = "causal";
/* Noise layers */ /* Noise layers */
private final String LAYER_FIELD_RATE = "rate"; private final String LAYER_FIELD_RATE = "rate";

View File

@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer {
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
break; break;
case 2: case 2:
if(this.dimOrder != null) {
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
}
break; break;
case 3: case 3:
switch (this.dimOrder) { switch (this.dimOrder) {

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList; import java.util.ArrayList;
@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer {
*/ */
public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException { public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
if (type instanceof InputType.InputTypeFeedForward) { if (type instanceof InputType.InputTypeFeedForward) {
this.layer = new LossLayer.Builder(loss).name(this.layerName).build(); this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
} }
else if (type instanceof InputType.InputTypeRecurrent) { else if (type instanceof InputType.InputTypeRecurrent) {
this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build(); this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
} }
else if (type instanceof InputType.InputTypeConvolutional) { else if (type instanceof InputType.InputTypeConvolutional) {
this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build(); this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build();
} else { } else {
throw new UnsupportedKerasConfigurationException("Unsupported output layer type" throw new UnsupportedKerasConfigurationException("Unsupported output layer type"
+ "got : " + type.toString()); + "got : " + type.toString());

View File

@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer {
public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig) public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig); super(layerConfig, enforceTrainingConfig);
} }
/** /**

View File

@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution {
break; break;
case THEANO: case THEANO:
paramValue = kerasParamValue.permute(2, 1, 0); //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1]
paramValue = paramValue.reshape( long k = kerasParamValue.size(0);
paramValue.size(0), paramValue.size(1), long nIn = kerasParamValue.size(1);
paramValue.size(2), 1).dup(); long nOut = kerasParamValue.size(2);
for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) { paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1);
INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
double[] flattenedFilter = copyFilter.ravel().data().asDouble();
ArrayUtils.reverse(flattenedFilter);
INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType()));
}
break; break;
default: default:
throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder()); throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());

View File

@ -264,7 +264,8 @@ public class KerasConvolutionUtils {
} else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) || } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) ||
borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) { borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) {
convolutionMode = ConvolutionMode.Truncate; convolutionMode = ConvolutionMode.Truncate;
} else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) {
convolutionMode = ConvolutionMode.Causal;
} else { } else {
throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode); throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode);
} }

View File

@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten). // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()}; val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape); preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
} }
return preprocessor; return preprocessor;
} }

View File

@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
} else { } else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]}; targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
} }
preprocessor = new ReshapePreprocessor(inputShape, targetShape); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2) } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0]) if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]}; targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} }
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) { } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer {
} else { } else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] }; targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
} }
preprocessor = new ReshapePreprocessor(inputShape, targetShape); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else { } else {
if (inputShape[0] != targetShape[0]) if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] }; targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} }
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) { } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0]; InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()}; val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape); preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) { } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()}; val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) { if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape); targetShape = targetShapeForDimOrder(inputShape, targetShape);
} }
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape); preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
} }
return preprocessor; return preprocessor;
} }

View File

@ -23,11 +23,13 @@ import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer {
.biasInit(0.0) // TODO: this is incorrect .biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
if (biasConstraint != null) if (biasConstraint != null)
builder.constrainBias(biasConstraint); builder.constrainBias(biasConstraint);
if (weightConstraint != null) if (weightConstraint != null)
@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer {
log.warn("Attemping to set weights for unknown parameters: " log.warn("Attemping to set weights for unknown parameters: "
+ unknownParamNames.substring(1, unknownParamNames.length() - 1)); + unknownParamNames.substring(1, unknownParamNames.length() - 1));
} }
FeedForwardLayer ffl;
if(this.layer instanceof BaseWrapperLayer){
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
ffl = (FeedForwardLayer)bwl.getUnderlying();
} else {
ffl = (FeedForwardLayer) this.layer;
}
if(ffl.getNIn() != wRows){
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
//We can reliably infer nIn from the shape of the weights array however
ffl.setNIn(wRows);
}
} }
/** /**

View File

@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer {
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
if (biasConstraint != null) if (biasConstraint != null)
builder.constrainBias(biasConstraint); builder.constrainBias(biasConstraint);
if (weightConstraint != null) if (weightConstraint != null)
@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer {
log.warn("Attemping to set weights for unknown parameters: " log.warn("Attemping to set weights for unknown parameters: "
+ unknownParamNames.substring(1, unknownParamNames.length() - 1)); + unknownParamNames.substring(1, unknownParamNames.length() - 1));
} }
FeedForwardLayer ffl;
if(this.layer instanceof BaseWrapperLayer){
BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
ffl = (FeedForwardLayer)bwl.getUnderlying();
} else {
ffl = (FeedForwardLayer) this.layer;
}
if(ffl.getNIn() != W.rows()){
//Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config)
//We can reliably infer nIn from the shape of the weights array however
ffl.setNIn(W.rows());
}
} }
} }

View File

@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer {
@Override @Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException { public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
Map<String, INDArray> forwardWeights = getUnderlyingWeights(weights, "forward"); Map<String, INDArray> forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
Map<String, INDArray> backwardWeights = getUnderlyingWeights(weights, "backward"); Map<String, INDArray> backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");
this.weights = new HashMap<>(); this.weights = new HashMap<>();
@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer {
} }
private Map<String, INDArray> getUnderlyingWeights(Map<String, INDArray> weights, String direction) private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction)
throws InvalidKerasConfigurationException { throws InvalidKerasConfigurationException {
int keras1SubstringLength; int keras1SubstringLength;
if (kerasRnnlayer instanceof KerasLSTM) if (kerasRnnlayer instanceof KerasLSTM)
@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer {
weights = newWeights; weights = newWeights;
} }
Layer layerBefore = kerasRnnlayer.getLayer();
kerasRnnlayer.setLayer(l);
kerasRnnlayer.setWeights(weights); kerasRnnlayer.setWeights(weights);
return kerasRnnlayer.getWeights(); Map<String,INDArray> ret = kerasRnnlayer.getWeights();
kerasRnnlayer.setLayer(layerBefore);
return ret;
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -20,7 +21,6 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@ -36,42 +36,59 @@ import java.util.Arrays;
import static org.nd4j.linalg.util.ArrayUtil.prodLong; import static org.nd4j.linalg.util.ArrayUtil.prodLong;
/** /**
* Generic reshape preprocessor * Generic reshape preprocessor.
* Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
* is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}<br>
* For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:<br>
* hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR<br>
* hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
* *
* @author Max Pumperla * @author Max Pumperla
*/ */
@Data @Data
@Slf4j @Slf4j
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"}) @JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor extends BaseInputPreProcessor { public class ReshapePreprocessor extends BaseInputPreProcessor {
private long[] inputShape; private final long[] inputShape;
private long[] targetShape; private final long[] targetShape;
private boolean hasMiniBatchDimension = false; private boolean hasMiniBatchDimension;
private int miniBatchSize;
private long[] staticTargetShape;
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { /**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
this.inputShape = inputShape; this.inputShape = inputShape;
this.targetShape = targetShape; this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
} }
private static int prod(int[] array) { private long[] getShape(long[] originalShape, long minibatch) {
int prod = 1; long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
for (int i : array) { if (newShape[0] != minibatch) {
prod *= i; newShape = newShape.clone();
newShape[0] = minibatch;
} }
return prod; return newShape;
} }
private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) { private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
int shapeLength = shape.length; int shapeLength = shape.length;
val miniBatchShape = new long[shapeLength + 1]; val miniBatchShape = new long[shapeLength + 1];
for (int i = 0; i < miniBatchShape.length; i++) { miniBatchShape[0] = miniBatchSize;
if (i == 0) for (int i = 1; i < miniBatchShape.length; i++) {
miniBatchShape[i] = miniBatchSize;
else
miniBatchShape[i] = shape[i - 1]; miniBatchShape[i] = shape[i - 1];
} }
return miniBatchShape; return miniBatchShape;
@ -79,30 +96,12 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override @Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
// the target shape read from a keras config does not have mini-batch size // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
// included. We prepend it here dynamically. long[] targetShape = getShape(this.targetShape, miniBatchSize);
long[] inputShape = getShape(this.inputShape, miniBatchSize);
long[] targetShape;
if (staticTargetShape != null){
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize;
}
else{
targetShape = this.targetShape;
}
if (!this.hasMiniBatchDimension) {
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
this.miniBatchSize = miniBatchSize;
}
if (this.miniBatchSize != miniBatchSize) {
targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
this.miniBatchSize = miniBatchSize;
}
if (prodLong(input.shape()) == prodLong((targetShape))) { if (prodLong(input.shape()) == prodLong((targetShape))) {
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
} }
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override @Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
long[] targetShape = getShape(this.targetShape, miniBatchSize);
long[] inputShape = getShape(this.inputShape, miniBatchSize);
if (!Arrays.equals(targetShape, output.shape())) { if (!Arrays.equals(targetShape, output.shape())) {
throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape()) throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
+ " (expected to be " + Arrays.toString(targetShape) + ")"); + " (expected to be " + Arrays.toString(targetShape) + ")");
} }
if (prodLong(output.shape()) == prodLong((targetShape))) { if (prodLong(output.shape()) == prodLong((targetShape))) {
if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){ if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c'); output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
} }
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape)); return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
} else { } else {
throw new IllegalStateException("Output shape" + Arrays.toString(output.shape()) throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
+ " and input shape" + Arrays.toString(targetShape) + " do not match"); + " and input shape" + Arrays.toString(targetShape) + " do not match");
@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override @Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0); long[] shape = getShape(this.targetShape, 0);
InputType ret; InputType ret;
switch (shape.length) { switch (shape.length) {
case 2: case 2:
@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.recurrent(shape[2], shape[1]); ret = InputType.recurrent(shape[2], shape[1]);
break; break;
case 4: case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]); ret = InputType.convolutional(shape[1], shape[2], shape[3]);
}else { } else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]); ret = InputType.convolutional(shape[2], shape[3], shape[1]);
} }
break; break;
default: default:
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Cannot infer input type for reshape array " + Arrays.toString(shape)); "Cannot infer input type for reshape array " + Arrays.toString(shape));
} }
this.staticTargetShape = ret.getShape();
return ret; return ret;
} }
} }

View File

@ -505,6 +505,17 @@ public class KerasLayerUtils {
return nOut; return nOut;
} }
public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){
Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
if(id instanceof Number){
return ((Number)id).intValue();
}
}
return null;
}
/** /**
* Get dropout from Keras layer configuration. * Get dropout from Keras layer configuration.
* *

View File

@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
@Test @Test
public void ReshapeEmbeddingConcatTest() throws Exception{ public void ReshapeEmbeddingConcatTest() throws Exception{
//TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config = ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is) new KerasModel().modelBuilder().modelJsonInputStream(is)
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
ComputationGraph model = new ComputationGraph(config); ComputationGraph model = new ComputationGraph(config);
model.init(); model.init();
// System.out.println(model.summary());
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1)); model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
} }
} }

View File

@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.BiFunction;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
@ -58,10 +62,7 @@ import java.io.InputStream;
import java.net.URL; import java.net.URL;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.StandardCopyOption; import java.nio.file.StandardCopyOption;
import java.util.HashMap; import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Random;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -86,6 +87,15 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Rule @Rule
public final TemporaryFolder testDir = new TemporaryFolder(); public final TemporaryFolder testDir = new TemporaryFolder();
public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
if(array.rank() == 3)
return array.permute(0, 2, 1); //NWC to NCW
return array;
}
};
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void fileNotFoundEndToEnd() throws Exception { public void fileNotFoundEndToEnd() throws Exception {
String modelPath = "modelimport/keras/examples/foo/bar.h5"; String modelPath = "modelimport/keras/examples/foo/bar.h5";
@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception { public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
} }
@Test @Test
public void importImdbLstmThKeras1() throws Exception { public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
} }
@Test @Test
public void importImdbLstmTfKeras2() throws Exception { public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
} }
@Test @Test
public void importImdbLstmThKeras2() throws Exception { public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
} }
/** /**
@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
} }
/** /**
@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
model.summary(); model.summary();
} }
@Test
public void testCausalCon1D() throws Exception {
String[] names = new String[]{
"causal_conv1d_k2_s1_d1_cl_model.h5",
"causal_conv1d_k2_s1_d2_cl_model.h5",
"causal_conv1d_k2_s2_d1_cl_model.h5",
"causal_conv1d_k2_s3_d1_cl_model.h5",
"causal_conv1d_k3_s1_d1_cl_model.h5",
"causal_conv1d_k3_s1_d2_cl_model.h5",
"causal_conv1d_k3_s2_d1_cl_model.h5",
"causal_conv1d_k3_s3_d1_cl_model.h5",
"causal_conv1d_k4_s1_d1_cl_model.h5",
"causal_conv1d_k4_s1_d2_cl_model.h5",
"causal_conv1d_k4_s2_d1_cl_model.h5",
"causal_conv1d_k4_s3_d1_cl_model.h5"
};
for(String name : names ){
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, nwc2ncwExpected);
Layer l = net.getLayer(0);
Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
}
}
@Test
public void testCon1D() throws Exception {
String[] names = new String[]{
"conv1d_k2_s1_d1_cf_same_model.h5",
"conv1d_k2_s1_d1_cf_valid_model.h5",
"conv1d_k2_s1_d1_cl_same_model.h5",
"conv1d_k2_s1_d1_cl_valid_model.h5",
"conv1d_k2_s1_d2_cf_same_model.h5",
"conv1d_k2_s1_d2_cf_valid_model.h5",
"conv1d_k2_s1_d2_cl_same_model.h5",
"conv1d_k2_s1_d2_cl_valid_model.h5",
"conv1d_k2_s2_d1_cf_same_model.h5",
"conv1d_k2_s2_d1_cf_valid_model.h5",
"conv1d_k2_s2_d1_cl_same_model.h5",
"conv1d_k2_s2_d1_cl_valid_model.h5",
"conv1d_k2_s3_d1_cf_same_model.h5",
"conv1d_k2_s3_d1_cf_valid_model.h5",
"conv1d_k2_s3_d1_cl_same_model.h5",
"conv1d_k2_s3_d1_cl_valid_model.h5",
"conv1d_k3_s1_d1_cf_same_model.h5",
"conv1d_k3_s1_d1_cf_valid_model.h5",
"conv1d_k3_s1_d1_cl_same_model.h5",
"conv1d_k3_s1_d1_cl_valid_model.h5",
"conv1d_k3_s1_d2_cf_same_model.h5",
"conv1d_k3_s1_d2_cf_valid_model.h5",
"conv1d_k3_s1_d2_cl_same_model.h5",
"conv1d_k3_s1_d2_cl_valid_model.h5",
"conv1d_k3_s2_d1_cf_same_model.h5",
"conv1d_k3_s2_d1_cf_valid_model.h5",
"conv1d_k3_s2_d1_cl_same_model.h5",
"conv1d_k3_s2_d1_cl_valid_model.h5",
"conv1d_k3_s3_d1_cf_same_model.h5",
"conv1d_k3_s3_d1_cf_valid_model.h5",
"conv1d_k3_s3_d1_cl_same_model.h5",
"conv1d_k3_s3_d1_cl_valid_model.h5",
"conv1d_k4_s1_d1_cf_same_model.h5",
"conv1d_k4_s1_d1_cf_valid_model.h5",
"conv1d_k4_s1_d1_cl_same_model.h5",
"conv1d_k4_s1_d1_cl_valid_model.h5",
"conv1d_k4_s1_d2_cf_same_model.h5",
"conv1d_k4_s1_d2_cf_valid_model.h5",
"conv1d_k4_s1_d2_cl_same_model.h5",
"conv1d_k4_s1_d2_cl_valid_model.h5",
"conv1d_k4_s2_d1_cf_same_model.h5",
"conv1d_k4_s2_d1_cf_valid_model.h5",
"conv1d_k4_s2_d1_cl_same_model.h5",
"conv1d_k4_s2_d1_cl_valid_model.h5",
"conv1d_k4_s3_d1_cf_same_model.h5",
"conv1d_k4_s3_d1_cf_valid_model.h5",
"conv1d_k4_s3_d1_cl_same_model.h5",
"conv1d_k4_s3_d1_cl_valid_model.h5",
};
for(String name : names ){
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
// if("conv".equals(s)){
return array.permute(0, 2, 1);
// }
}
};
importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, f2);
}
}
private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
return importFunctionalModelH5Test(modelPath, null, false); return importFunctionalModelH5Test(modelPath, null, false);
} }
@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients, boolean enforceTrainingConfig) throws Exception { boolean checkGradients, boolean enforceTrainingConfig) throws Exception {
return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null);
}
public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions,
boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function<INDArray,INDArray> inputPreProc,
BiFunction<String,INDArray,INDArray> expectedPreProc) throws Exception {
MultiLayerNetwork model; MultiLayerNetwork model;
try(InputStream is = Resources.asStream(modelPath)) { try(InputStream is = Resources.asStream(modelPath)) {
File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
if (checkPredictions) { if (checkPredictions) {
INDArray input = getInputs(outputsArchive, tfOrdering)[0]; INDArray input = getInputs(outputsArchive, tfOrdering)[0];
if(inputPreProc != null)
input = inputPreProc.apply(input);
Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering); Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
for (int i = 0; i < model.getLayers().length; i++) { for (int i = 0; i < model.getLayers().length; i++) {
String layerName = model.getLayerNames().get(i); String layerName = model.getLayerNames().get(i);
if (activationsKeras.containsKey(layerName)) { if (activationsKeras.containsKey(layerName)) {
INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1); INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
if (activationsDl4j.shape().length == 3) INDArray exp = activationsKeras.get(layerName);
activationsDl4j = activationsDl4j.permute(0, 2, 1); if(expectedPreProc != null)
compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS); exp = expectedPreProc.apply(layerName, exp);
compareINDArrays(layerName, exp, activationsDl4j, EPS);
} }
} }
INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0]; INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
INDArray predictionsDl4j = model.output(input, false); INDArray predictionsDl4j = model.output(input, false);
if(expectedPreProc != null)
predictionsKeras = expectedPreProc.apply("output", predictionsKeras);
compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS); compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
INDArray outputs = getOutputs(outputsArchive, true)[0]; INDArray outputs = getOutputs(outputsArchive, true)[0];
@ -680,6 +817,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
} }
val nOut = (int) outputs.size(-1); val nOut = (int) outputs.size(-1);
if(checkAuc)
compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
} }
@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
return predictions; return predictions;
} }
private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) { private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) {
INDArray diff = a.sub(b.castTo(a.dataType())); if(!expected.equalShapes(actual)){
throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
}
INDArray diff = expected.sub(actual.castTo(expected.dataType()));
double min = diff.minNumber().doubleValue(); double min = diff.minNumber().doubleValue();
double max = diff.maxNumber().doubleValue(); double max = diff.maxNumber().doubleValue();
log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max); log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max);
double threshold = 1e-7; double threshold = 1e-7;
double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue())); double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue()));
double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue())); double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue()));
// skip too small absolute inputs // skip too small absolute inputs
if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps)); boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
assertTrue("Output differs: " + label, eq);
} }
} }
private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -23,16 +24,11 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId> <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
<packaging>jar</packaging> <packaging>jar</packaging>
<name>deeplearning4j-nearestneighbor-server</name> <name>deeplearning4j-nearestneighbor-server</name>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<build> <build>
<pluginManagement> <pluginManagement>
<plugins> <plugins>
@ -73,29 +69,17 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>io.vertx</groupId>
<artifactId>play-java_2.11</artifactId> <artifactId>vertx-core</artifactId>
<version>${playframework.version}</version> <version>${vertx.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat</groupId>
<artifactId>tomcat-servlet-api</artifactId>
</exclusion>
<exclusion>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>net.jodah</groupId> <groupId>io.vertx</groupId>
<artifactId>typetools</artifactId> <artifactId>vertx-web</artifactId>
<version>${jodah.typetools.version}</version> <version>${vertx.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mashape.unirest</groupId> <groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId> <artifactId>unirest-java</artifactId>
@ -108,25 +92,16 @@
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.beust</groupId> <groupId>com.beust</groupId>
<artifactId>jcommander</artifactId> <artifactId>jcommander</artifactId>
<version>${jcommander.version}</version> <version>${jcommander.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>play-netty-server_2.11</artifactId> <artifactId>logback-classic</artifactId>
<version>${playframework.version}</version> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
@ -144,11 +119,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException; import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint; import org.deeplearning4j.clustering.sptree.DataPoint;
@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*; import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde; import org.nd4j.serde.binary.BinarySerde;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Json;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
import static play.mvc.Controller.request;
import static play.mvc.Results.*;
/** /**
* A rest server for using an * A rest server for using an
* {@link VPTree} based on loading an ndarray containing * {@link VPTree} based on loading an ndarray containing
@ -57,7 +55,9 @@ import static play.mvc.Results.*;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
public class NearestNeighborsServer { public class NearestNeighborsServer extends AbstractVerticle {
private static class RunArgs {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null; private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false) @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
@ -68,11 +68,20 @@ public class NearestNeighborsServer {
private String similarityFunction = "euclidean"; private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1) @Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false; private boolean invert = false;
}
private Server server; private static RunArgs instanceArgs;
private static NearestNeighborsServer instance;
public void runMain(String... args) throws Exception { public NearestNeighborsServer(){ }
JCommander jcmdr = new JCommander(this);
public static NearestNeighborsServer getInstance(){
return instance;
}
public static void runMain(String... args) {
RunArgs r = new RunArgs();
JCommander jcmdr = new JCommander(r);
try { try {
jcmdr.parse(args); jcmdr.parse(args);
@ -84,7 +93,7 @@ public class NearestNeighborsServer {
//User provides invalid input -> print the usage info //User provides invalid input -> print the usage info
jcmdr.usage(); jcmdr.usage();
if (ndarrayPath == null) if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)"); log.error("Json path parameter is missing (null)");
try { try {
Thread.sleep(500); Thread.sleep(500);
@ -93,16 +102,20 @@ public class NearestNeighborsServer {
System.exit(1); System.exit(1);
} }
instanceArgs = r;
try { try {
runHelper(); Vertx vertx = Vertx.vertx();
vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){ } catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t); log.error("Error in NearestNeighboursServer run method",t);
} }
} }
protected void runHelper() throws Exception { @Override
public void start() throws Exception {
instance = this;
String[] pathArr = ndarrayPath.split(","); String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length]; //INDArray[] pointsArr = new INDArray[pathArr.length];
// first of all we reading shapes of saved eariler files // first of all we reading shapes of saved eariler files
int rows = 0; int rows = 0;
@ -126,8 +139,8 @@ public class NearestNeighborsServer {
} }
final List<String> labels = new ArrayList<>(); final List<String> labels = new ArrayList<>();
if (labelsPath != null) { if (instanceArgs.labelsPath != null) {
String[] labelsPathArr = labelsPath.split(","); String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) { for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
} }
@ -149,7 +162,7 @@ public class NearestNeighborsServer {
System.gc(); System.gc();
} }
VPTree tree = new VPTree(points, similarityFunction, invert); VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required //Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret //http://www.playframework.com/documentation/latest/ApplicationSecret
@ -163,40 +176,57 @@ public class NearestNeighborsServer {
System.setProperty("play.crypto.secret", base64); System.setProperty("play.crypto.secret", base64);
} }
Router r = Router.router(vertx);
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
createRoutes(r, labels, tree, points);
server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); vertx.createHttpServer()
.requestHandler(r)
.listen(instanceArgs.port);
} }
protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){ private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
//return the host information for a given id r.post("/knn").handler(rc -> {
routingDsl.POST("/knn").routingTo(request -> {
try { try {
NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); String json = rc.getBodyAsJson().encode();
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
NearestNeighbor nearestNeighbor = NearestNeighbor nearestNeighbor =
NearestNeighbor.builder().points(points).record(record).tree(tree).build(); NearestNeighbor.builder().points(points).record(record).tree(tree).build();
if (record == null) if (record == null) {
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
NearestNeighborsResults results = NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
return ok(Json.toJson(results));
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(results));
return;
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knn",e); log.error("Error in POST /knn",e);
e.printStackTrace(); e.printStackTrace();
return internalServerError(e.getMessage()); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
} }
}); });
routingDsl.POST("/knnnew").routingTo(request -> { r.post("/knnnew").handler(rc -> {
try { try {
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); String json = rc.getBodyAsJson().encode();
if (record == null) Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); if (record == null) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List<DataPoint> results; List<DataPoint> results;
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
} }
if (results.size() != distances.size()) { if (results.size() != distances.size()) {
return internalServerError( rc.response()
String.format("results.size == %d != %d == distances.size", .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
results.size(), distances.size())); .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
return;
} }
List<NearestNeighborsResult> nnResult = new ArrayList<>(); List<NearestNeighborsResult> nnResult = new ArrayList<>();
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
} }
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
return ok(Json.toJson(results2)); String j = JsonMappers.getMapper().writeValueAsString(results2);
rc.response()
.putHeader("content-type", "application/json")
.end(j);
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knnnew",e); log.error("Error in POST /knnnew",e);
e.printStackTrace(); e.printStackTrace();
return internalServerError(e.getMessage()); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
} }
}); });
return routingDsl.build();
} }
/** /**
* Stop the server * Stop the server
*/ */
public void stop() { public void stop() throws Exception {
if (server != null) { super.stop();
log.info("Attempting to stop server");
server.stop();
}
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
new NearestNeighborsServer().runMain(args); runMain(args);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testNearestNeighbor() { public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data); INDArray arr = Nd4j.create(data);
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
File writeToTmp = testDir.newFile(); File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit(); writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp); BinarySerde.writeArrayToDisk(rand, writeToTmp);
NearestNeighborsServer server = new NearestNeighborsServer(); NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort)); String.valueOf(localPort));
Thread.sleep(3000);
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size()); assertEquals(5, result.getResults().size());
server.stop(); NearestNeighborsServer.getInstance().stop();
} }

View File

@ -0,0 +1,42 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2019 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -54,7 +54,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -53,7 +53,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -83,11 +83,11 @@
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId> <artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -43,7 +44,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -66,7 +66,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -68,7 +68,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -61,7 +61,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -79,7 +79,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -84,7 +84,7 @@
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.2</id>
</profile> </profile>
</profiles> </profiles>

View File

@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -85,10 +86,20 @@ import java.util.Map;
* <pre> * <pre>
* {@code * {@code
* BertIterator b; * BertIterator b;
* Pair<INDArray[],INDArray[]> featuresAndMask;
* INDArray[] features;
* INDArray[] featureMasks;
*
* //With sentences
* List<String> forInference; * List<String> forInference;
* Pair<INDArray[],INDArray[]> featuresAndMask = b.featurizeSentences(forInference); * featuresAndMask = b.featurizeSentences(forInference);
* INDArray[] features = featuresAndMask.getFirst(); *
* INDArray[] featureMasks = featuresAndMask.getSecond(); * //OR with sentence pairs
* List<Pair<String, String>> forInferencePair};
* featuresAndMask = b.featurizeSentencePairs(forInference);
*
* features = featuresAndMask.getFirst();
* featureMasks = featuresAndMask.getSecond();
* } * }
* </pre> * </pre>
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br> * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.<br>
@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator {
@Setter @Setter
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected LengthHandling lengthHandling; protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays; protected FeatureArrays featureArrays;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator {
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken; protected String maskToken;
protected String prependToken; protected String prependToken;
protected String appendToken;
protected List<String> vocabKeysAsList; protected List<String> vocabKeysAsList;
@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator {
this.padMinibatches = b.padMinibatches; this.padMinibatches = b.padMinibatches;
this.preProcessor = b.preProcessor; this.preProcessor = b.preProcessor;
this.sentenceProvider = b.sentenceProvider; this.sentenceProvider = b.sentenceProvider;
this.sentencePairProvider = b.sentencePairProvider;
this.lengthHandling = b.lengthHandling; this.lengthHandling = b.lengthHandling;
this.featureArrays = b.featureArrays; this.featureArrays = b.featureArrays;
this.vocabMap = b.vocabMap; this.vocabMap = b.vocabMap;
@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator {
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat; this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken; this.maskToken = b.maskToken;
this.prependToken = b.prependToken; this.prependToken = b.prependToken;
this.appendToken = b.appendToken;
} }
@Override @Override
public boolean hasNext() { public boolean hasNext() {
if (sentenceProvider != null)
return sentenceProvider.hasNext(); return sentenceProvider.hasNext();
return sentencePairProvider.hasNext();
} }
@Override @Override
@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator {
@Override @Override
public MultiDataSet next(int num) { public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available"); Preconditions.checkState(hasNext(), "No next element available");
List<Pair<List<String>, String>> tokensAndLabelList;
List<Pair<String, String>> list = new ArrayList<>(num);
int mbSize = 0; int mbSize = 0;
int outLength;
long[] segIdOnesFrom = null;
if (sentenceProvider != null) { if (sentenceProvider != null) {
List<Pair<String, String>> list = new ArrayList<>(num);
while (sentenceProvider.hasNext() && mbSize++ < num) { while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence()); list.add(sentenceProvider.nextSentence());
} }
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
outLength = sentenceListProcessed.getMaxL();
} else if (sentencePairProvider != null) {
List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
while (sentencePairProvider.hasNext() && mbSize++ < num) {
listPairs.add(sentencePairProvider.nextSentencePair());
}
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
outLength = sentencePairListProcessed.getMaxL();
segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
} else { } else {
//TODO - other types of iterators... //TODO - other types of iterators...
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented"); throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
} }
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst(); INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond(); INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength); Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst(); INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond(); INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) { public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences); List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
int outLength = sentenceListProcessed.getMaxL();
Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
int outLength = outLTokenizedSentencesPair.getLeft();
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
if (preProcessor != null) { if (preProcessor != null) {
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null); MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS); preProcessor.preProcess(dummyMDS);
return new Pair<INDArray[],INDArray[]>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
} }
return convertMiniBatchFeatures(tokenizedSentences, outLength); return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
} }
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) { /**
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size(); * For use during inference. Will convert a given pair of a list of sentences to features and feature masks as appropriate.
*
* @param listOnlySentencePairs
* @return Pair of INDArrays[], first element is feature arrays and the second is the masks array
*/
public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
int outLength = sentencePairListProcessed.getMaxL();
long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
if (preProcessor != null) {
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
}
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
int[][] outIdxs = new int[mbPadded][outLength]; int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength];
for (int i = 0; i < tokenizedSentences.size(); i++) { int[][] outSegmentId = null;
Pair<List<String>, String> p = tokenizedSentences.get(i); if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
outSegmentId = new int[mbPadded][outLength];
for (int i = 0; i < tokensAndLabelList.size(); i++) {
Pair<List<String>, String> p = tokensAndLabelList.get(i);
List<String> t = p.getFirst(); List<String> t = p.getFirst();
for (int j = 0; j < outLength && j < t.size(); j++) { for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j)); Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j)); int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx; outIdxs[i][j] = idx;
outMask[i][j] = 1; outMask[i][j] = 1;
if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
outSegmentId[i][j] = 1;
} }
} }
@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator {
INDArray[] f; INDArray[] f;
INDArray[] fm; INDArray[] fm;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) { if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
//For now: always segment index 0 (only single s sequence input supported) outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
f = new INDArray[]{outIdxsArr, outSegmentIdArr}; f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null}; fm = new INDArray[]{outMaskArr, null};
} else { } else {
@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator {
return new Pair<>(f, fm); return new Pair<>(f, fm);
} }
private Pair<Integer, List<Pair<List<String>, String>>> tokenizeMiniBatch(List<Pair<String, String>> list) { private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
//Get and tokenize the sentences for this minibatch //Get and tokenize the sentences for this minibatch
List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(list.size()); SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
int longestSeq = -1; int longestSeq = -1;
for (Pair<String, String> p : list) { for (Pair<String, String> p : list) {
List<String> tokens = tokenizeSentence(p.getFirst()); List<String> tokens = tokenizeSentence(p.getFirst());
tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size()); longestSeq = Math.max(longestSeq, tokens.size());
} }
//Determine output array length... //Determine output array length...
int outLength; int outLength;
switch (lengthHandling) { switch (lengthHandling) {
@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator {
default: default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
} }
return new Pair<>(outLength, tokenizedSentences); sentenceListProcessed.setMaxL(outLength);
return sentenceListProcessed;
}
private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
for (Triple<String, String, String> t : listPairs) {
List<String> tokensL = tokenizeSentence(t.getFirst(), true);
List<String> tokensR = tokenizeSentence(t.getSecond(), true);
List<String> tokens = new ArrayList<>(maxTokens);
int maxLength = maxTokens;
if (prependToken != null)
maxLength--;
if (appendToken != null)
maxLength -= 2;
if (tokensL.size() + tokensR.size() > maxLength) {
boolean shortOnL = tokensL.size() < tokensR.size();
int shortSize = Math.min(tokensL.size(), tokensR.size());
if (shortSize > maxLength / 2) {
//both lists need to be sliced
tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF
tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
} else {
//slice longer list
if (shortOnL) {
//longer on R - slice R
tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
} else {
//longer on L - slice L
tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
}
}
}
if (prependToken != null)
tokens.add(prependToken);
tokens.addAll(tokensL);
if (appendToken != null)
tokens.add(appendToken);
int segIdOnesFrom = tokens.size();
tokens.addAll(tokensR);
if (appendToken != null)
tokens.add(appendToken);
sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
}
sentencePairListProcessed.setMaxL(maxTokens);
return sentencePairListProcessed;
} }
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator {
classLabels[i] = labels.indexOf(lbl); classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
} }
} else if (sentencePairProvider != null) {
numClasses = sentencePairProvider.numLabelClasses();
List<String> labels = sentencePairProvider.allLabels();
for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
} else { } else {
throw new RuntimeException(); throw new RuntimeException();
} }
@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator {
} }
private List<String> tokenizeSentence(String sentence) { private List<String> tokenizeSentence(String sentence) {
return tokenizeSentence(sentence, false);
}
private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
Tokenizer t = tokenizerFactory.create(sentence); Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>(); List<String> tokens = new ArrayList<>();
if (prependToken != null) if (prependToken != null && !ignorePrependAppend)
tokens.add(prependToken); tokens.add(prependToken);
while (t.hasMoreTokens()) { while (t.hasMoreTokens()) {
String token = t.nextToken(); String token = t.nextToken();
tokens.add(token); tokens.add(token);
} }
if (appendToken != null && !ignorePrependAppend)
tokens.add(appendToken);
return tokens; return tokens;
} }
@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator {
return list; return list;
} }
private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
for (Pair<String, String> p : listOnlySentencePairs) {
list.add(new Triple<String, String, String>(p.getFirst(), p.getSecond(), null));
}
return list;
}
@Override @Override
public boolean resetSupported() { public boolean resetSupported() {
@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator {
protected boolean padMinibatches = false; protected boolean padMinibatches = false;
protected MultiDataSetPreProcessor preProcessor; protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null; protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken; protected String maskToken;
protected String prependToken; protected String prependToken;
protected String appendToken;
/** /**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator {
} }
/** /**
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised * Specify the source of the data for classification.
* use case, the labels will be ignored.
*/ */
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider; this.sentenceProvider = sentenceProvider;
return this; return this;
} }
/**
* Specify the source of the data for classification on sentence pairs.
*/
public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
this.sentencePairProvider = sentencePairProvider;
return this;
}
/** /**
* Specify what arrays should be returned. See {@link BertIterator} for more details. * Specify what arrays should be returned. See {@link BertIterator} for more details.
*/ */
@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator {
return this; return this;
} }
/**
* Append the specified token to the sequences, when doing training on sentence pairs.<br>
* Generally "[SEP]" is used
* No token in appended by default.
*
* @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
* @return
*/
public Builder appendToken(String appendToken) {
this.appendToken = appendToken;
return this;
}
public BertIterator build() { public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed"); Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required"); Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
@ -598,9 +737,69 @@ public class BertIterator implements MultiDataSetIterator {
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method"); Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified"); Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
if (sentencePairProvider != null) {
Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
}
if (appendToken != null) {
Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
}
return new BertIterator(this); return new BertIterator(this);
} }
} }
private static class SentencePairListProcessed {
private int listLength = 0;
@Getter
private long[] segIdOnesFrom;
private int cursor = 0;
private SentenceListProcessed sentenceListProcessed;
private SentencePairListProcessed(int listLength) {
this.listLength = listLength;
segIdOnesFrom = new long[listLength];
sentenceListProcessed = new SentenceListProcessed(listLength);
}
private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
segIdOnesFrom[cursor] = segIdIdx;
sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
cursor++;
}
private void setMaxL(int maxL) {
sentenceListProcessed.setMaxL(maxL);
}
private int getMaxL() {
return sentenceListProcessed.getMaxL();
}
private List<Pair<List<String>, String>> getTokensAndLabelList() {
return sentenceListProcessed.getTokensAndLabelList();
}
}
private static class SentenceListProcessed {
private int listLength;
@Getter
@Setter
private int maxL;
@Getter
private List<Pair<List<String>, String>> tokensAndLabelList;
private SentenceListProcessed(int listLength) {
this.listLength = listLength;
tokensAndLabelList = new ArrayList<>(listLength);
}
private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
tokensAndLabelList.add(tokenizedSentenceAndLabel);
}
}
} }

Some files were not shown because too many files have changed in this diff Show More