Python updates (#86)
* python updates * fix cyclic deps * konduit updates * konduit updates * fix list * fixes * sync pyvars test * setuprun comments * Version fix, other module test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * bug fix using advanced hacking skillzzmaster
parent
8123d9fa9b
commit
1adc25919c
|
@ -256,11 +256,9 @@ public class ExecutionTest {
|
|||
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(schema)
|
||||
.transform(
|
||||
new PythonTransform(
|
||||
"first = np.sin(first)\nsecond = np.cos(second)",
|
||||
schema
|
||||
)
|
||||
)
|
||||
PythonTransform.builder().code(
|
||||
"first = np.sin(first)\nsecond = np.cos(second)")
|
||||
.outputSchema(schema).build())
|
||||
.build();
|
||||
|
||||
List<List<Writable>> functions = new ArrayList<>();
|
||||
|
|
|
@ -14,35 +14,40 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.python;
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
import org.datavec.api.transform.TransformProcess;
|
||||
import org.datavec.api.transform.condition.Condition;
|
||||
import org.datavec.api.transform.filter.ConditionFilter;
|
||||
import org.datavec.api.transform.filter.Filter;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.junit.Ignore;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.python.PythonCondition;
|
||||
import org.datavec.python.PythonTransform;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
||||
import javax.annotation.concurrent.NotThreadSafe;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@NotThreadSafe
|
||||
public class TestPythonTransformProcess {
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
|
||||
@Test()
|
||||
public void testStringConcat() throws Exception{
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnString("col1")
|
||||
.addColumnString("col2");
|
||||
|
@ -54,7 +59,9 @@ public class TestPythonTransformProcess {
|
|||
String pythonCode = "col3 = col1 + col2";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(pythonCode, finalSchema)
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
||||
|
@ -68,7 +75,7 @@ public class TestPythonTransformProcess {
|
|||
|
||||
@Test(timeout = 60000L)
|
||||
public void testMixedTypes() throws Exception{
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
|
@ -83,11 +90,12 @@ public class TestPythonTransformProcess {
|
|||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(pythonCode, finalSchema)
|
||||
).build();
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.inputSchema(initialSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList((Writable)
|
||||
new IntWritable(10),
|
||||
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
|
||||
new FloatWritable(3.5f),
|
||||
new Text("5"),
|
||||
new DoubleWritable(2.0)
|
||||
|
@ -105,7 +113,7 @@ public class TestPythonTransformProcess {
|
|||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
@ -116,11 +124,13 @@ public class TestPythonTransformProcess {
|
|||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(pythonCode, finalSchema)
|
||||
).build();
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable) new NDArrayWritable(arr1),
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
|
@ -139,7 +149,7 @@ public class TestPythonTransformProcess {
|
|||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
@ -150,11 +160,13 @@ public class TestPythonTransformProcess {
|
|||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(pythonCode, finalSchema)
|
||||
).build();
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable) new NDArrayWritable(arr1),
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
|
@ -172,7 +184,7 @@ public class TestPythonTransformProcess {
|
|||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
||||
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
@ -183,11 +195,14 @@ public class TestPythonTransformProcess {
|
|||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(pythonCode, finalSchema)
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable) new NDArrayWritable(arr1),
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
|
@ -200,7 +215,7 @@ public class TestPythonTransformProcess {
|
|||
|
||||
@Test(timeout = 60000L)
|
||||
public void testPythonFilter() {
|
||||
Schema schema = new Schema.Builder().addColumnInteger("column").build();
|
||||
Schema schema = new Builder().addColumnInteger("column").build();
|
||||
|
||||
Condition condition = new PythonCondition(
|
||||
"f = lambda: column < 0"
|
||||
|
@ -210,17 +225,17 @@ public class TestPythonTransformProcess {
|
|||
|
||||
Filter filter = new ConditionFilter(condition);
|
||||
|
||||
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testPythonFilterAndTransform() throws Exception{
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
|
@ -241,33 +256,85 @@ public class TestPythonTransformProcess {
|
|||
|
||||
String pythonCode = "col6 = str(col1 + col2)";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
new PythonTransform(
|
||||
pythonCode,
|
||||
finalSchema
|
||||
)
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).filter(
|
||||
filter
|
||||
).build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(
|
||||
Arrays.asList((Writable) new IntWritable(5),
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList((Writable) new IntWritable(-3),
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(-3),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList((Writable) new IntWritable(5),
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(11.2f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
|
||||
LocalTransformExecutor.execute(inputs,tp);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnInteger("a")
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertEquals(3,execute.get(0).get(0).toInt());
|
||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumpyTransform() throws Exception {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnNDArray("a",new long[]{1,1})
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertFalse(execute.isEmpty());
|
||||
assertNotNull(execute.get(0));
|
||||
assertNotNull(execute.get(0).get(0));
|
||||
assertEquals("hello world",execute.get(0).get(0).toString());
|
||||
}
|
||||
|
||||
}
|
|
@ -28,15 +28,21 @@
|
|||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.googlecode.json-simple</groupId>
|
||||
<artifactId>json-simple</artifactId>
|
||||
<version>1.1</version>
|
||||
<groupId>org.json</groupId>
|
||||
<artifactId>json</artifactId>
|
||||
<version>20190722</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>cpython-platform</artifactId>
|
||||
<version>${cpython-platform.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>numpy-platform</artifactId>
|
||||
<version>${numpy.javacpp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.findbugs</groupId>
|
||||
<artifactId>jsr305</artifactId>
|
||||
|
|
|
@ -16,10 +16,13 @@
|
|||
|
||||
package org.datavec.python;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
* @author Fariz Rahman
|
||||
*/
|
||||
@Getter
|
||||
@NoArgsConstructor
|
||||
public class NumpyArray {
|
||||
|
||||
private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
private static NativeOps nativeOps;
|
||||
private long address;
|
||||
private long[] shape;
|
||||
private long[] strides;
|
||||
private DataType dtype = DataType.FLOAT;
|
||||
private DataType dtype;
|
||||
private INDArray nd4jArray;
|
||||
static {
|
||||
//initialize
|
||||
Nd4j.scalar(1.0);
|
||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
}
|
||||
|
||||
public NumpyArray(long address, long[] shape, long strides[], boolean copy){
|
||||
@Builder
|
||||
public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
|
||||
this.address = address;
|
||||
this.shape = shape;
|
||||
this.strides = strides;
|
||||
this.dtype = dtype;
|
||||
setND4JArray();
|
||||
if (copy){
|
||||
nd4jArray = nd4jArray.dup();
|
||||
|
@ -57,8 +68,9 @@ public class NumpyArray {
|
|||
public NumpyArray copy(){
|
||||
return new NumpyArray(nd4jArray.dup());
|
||||
}
|
||||
|
||||
public NumpyArray(long address, long[] shape, long strides[]){
|
||||
this(address, shape, strides, false);
|
||||
this(address, shape, strides, false,DataType.FLOAT);
|
||||
}
|
||||
|
||||
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
|
||||
|
@ -91,7 +103,8 @@ public class NumpyArray {
|
|||
for (int i = 0; i < strides.length; i++) {
|
||||
nd4jStrides[i] = strides[i] / elemSize;
|
||||
}
|
||||
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, 'c', dtype);
|
||||
|
||||
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ import org.datavec.api.writable.*;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
|
||||
|
||||
/**
|
||||
* Lets a condition be defined as a python method f that takes no arguments
|
||||
* and returns a boolean indicating whether or not to filter a row.
|
||||
|
@ -39,68 +41,14 @@ public class PythonCondition implements Condition {
|
|||
|
||||
|
||||
public PythonCondition(String pythonCode) {
|
||||
org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
|
||||
org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
|
||||
code = pythonCode;
|
||||
}
|
||||
|
||||
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
|
||||
PythonVariables pyVars = new PythonVariables();
|
||||
int numCols = schema.numColumns();
|
||||
for (int i=0; i<numCols; i++){
|
||||
String colName = schema.getName(i);
|
||||
ColumnType colType = schema.getType(i);
|
||||
switch (colType){
|
||||
case Long:
|
||||
case Integer:
|
||||
pyVars.addInt(colName);
|
||||
break;
|
||||
case Double:
|
||||
case Float:
|
||||
pyVars.addFloat(colName);
|
||||
break;
|
||||
case String:
|
||||
pyVars.addStr(colName);
|
||||
break;
|
||||
case NDArray:
|
||||
pyVars.addNDArray(colName);
|
||||
break;
|
||||
default:
|
||||
throw new Exception("Unsupported python input type: " + colType.toString());
|
||||
}
|
||||
}
|
||||
return pyVars;
|
||||
}
|
||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables){
|
||||
|
||||
PythonVariables ret = new PythonVariables();
|
||||
|
||||
for (String name: pyInputs.getVariables()){
|
||||
int colIdx = inputSchema.getIndexOfColumn(name);
|
||||
Writable w = writables.get(colIdx);
|
||||
PythonVariables.Type pyType = pyInputs.getType(name);
|
||||
switch (pyType){
|
||||
case INT:
|
||||
if (w instanceof LongWritable){
|
||||
ret.addInt(name, ((LongWritable)w).get());
|
||||
}
|
||||
else{
|
||||
ret.addInt(name, ((IntWritable)w).get());
|
||||
}
|
||||
|
||||
break;
|
||||
case FLOAT:
|
||||
ret.addFloat(name, ((DoubleWritable)w).get());
|
||||
break;
|
||||
case STR:
|
||||
ret.addStr(name, ((Text)w).toString());
|
||||
break;
|
||||
case NDARRAY:
|
||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@Override
|
||||
public void setInputSchema(Schema inputSchema) {
|
||||
this.inputSchema = inputSchema;
|
||||
|
@ -108,11 +56,12 @@ public class PythonCondition implements Condition {
|
|||
pyInputs = schemaToPythonVariables(inputSchema);
|
||||
PythonVariables pyOuts = new PythonVariables();
|
||||
pyOuts.addInt("out");
|
||||
pythonTransform = new PythonTransform(
|
||||
code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered
|
||||
pyInputs,
|
||||
pyOuts
|
||||
);
|
||||
pythonTransform = PythonTransform.builder()
|
||||
.code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
|
||||
.inputs(pyInputs)
|
||||
.outputs(pyOuts)
|
||||
.build();
|
||||
|
||||
}
|
||||
catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
|
@ -127,28 +76,34 @@ public class PythonCondition implements Condition {
|
|||
return inputSchema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] outputColumnNames() {
|
||||
String[] columnNames = new String[inputSchema.numColumns()];
|
||||
inputSchema.getColumnNames().toArray(columnNames);
|
||||
return columnNames;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String outputColumnName(){
|
||||
return outputColumnNames()[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] columnNames(){
|
||||
return outputColumnNames();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String columnName(){
|
||||
return outputColumnName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Schema transform(Schema inputSchema){
|
||||
return inputSchema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean condition(List<Writable> list) {
|
||||
PythonVariables inputs = getPyInputsFromWritables(list);
|
||||
try{
|
||||
|
@ -159,9 +114,9 @@ public class PythonCondition implements Condition {
|
|||
catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean condition(Object input){
|
||||
return condition(input);
|
||||
}
|
||||
|
@ -177,5 +132,37 @@ public class PythonCondition implements Condition {
|
|||
throw new UnsupportedOperationException("not supported");
|
||||
}
|
||||
|
||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
||||
PythonVariables ret = new PythonVariables();
|
||||
|
||||
for (int i = 0; i < inputSchema.numColumns(); i++){
|
||||
String name = inputSchema.getName(i);
|
||||
Writable w = writables.get(i);
|
||||
PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
|
||||
switch (pyType){
|
||||
case INT:
|
||||
if (w instanceof LongWritable) {
|
||||
ret.addInt(name, ((LongWritable)w).get());
|
||||
}
|
||||
else {
|
||||
ret.addInt(name, ((IntWritable)w).get());
|
||||
}
|
||||
|
||||
break;
|
||||
case FLOAT:
|
||||
ret.addFloat(name, ((DoubleWritable)w).get());
|
||||
break;
|
||||
case STR:
|
||||
ret.addStr(name, w.toString());
|
||||
break;
|
||||
case NDARRAY:
|
||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -16,16 +16,29 @@
|
|||
|
||||
package org.datavec.python;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.datavec.api.transform.Transform;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
|
||||
|
||||
/**
|
||||
* Row-wise Transform that applies arbitrary python code on each row
|
||||
*
|
||||
|
@ -35,30 +48,86 @@ import java.util.UUID;
|
|||
@NoArgsConstructor
|
||||
@Data
|
||||
public class PythonTransform implements Transform {
|
||||
|
||||
private String code;
|
||||
private PythonVariables pyInputs;
|
||||
private PythonVariables pyOutputs;
|
||||
private String name;
|
||||
private PythonVariables inputs;
|
||||
private PythonVariables outputs;
|
||||
private String name = UUID.randomUUID().toString();
|
||||
private Schema inputSchema;
|
||||
private Schema outputSchema;
|
||||
private String outputDict;
|
||||
private boolean returnAllVariables;
|
||||
private boolean setupAndRun = false;
|
||||
|
||||
|
||||
public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{
|
||||
@Builder
|
||||
public PythonTransform(String code,
|
||||
PythonVariables inputs,
|
||||
PythonVariables outputs,
|
||||
String name,
|
||||
Schema inputSchema,
|
||||
Schema outputSchema,
|
||||
String outputDict,
|
||||
boolean returnAllInputs,
|
||||
boolean setupAndRun) {
|
||||
Preconditions.checkNotNull(code,"No code found to run!");
|
||||
this.code = code;
|
||||
this.pyInputs = pyInputs;
|
||||
this.pyOutputs = pyOutputs;
|
||||
this.name = UUID.randomUUID().toString();
|
||||
this.returnAllVariables = returnAllInputs;
|
||||
this.setupAndRun = setupAndRun;
|
||||
if(inputs != null)
|
||||
this.inputs = inputs;
|
||||
if(outputs != null)
|
||||
this.outputs = outputs;
|
||||
|
||||
if(name != null)
|
||||
this.name = name;
|
||||
if (outputDict != null) {
|
||||
this.outputDict = outputDict;
|
||||
this.outputs = new PythonVariables();
|
||||
this.outputs.addDict(outputDict);
|
||||
|
||||
String helpers;
|
||||
try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
|
||||
helpers = IOUtils.toString(is, Charset.defaultCharset());
|
||||
|
||||
}catch (IOException e){
|
||||
throw new RuntimeException("Error reading python code");
|
||||
}
|
||||
this.code += "\n\n" + helpers;
|
||||
this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
|
||||
}
|
||||
|
||||
try {
|
||||
if(inputSchema != null) {
|
||||
this.inputSchema = inputSchema;
|
||||
if(inputs == null || inputs.isEmpty()) {
|
||||
this.inputs = schemaToPythonVariables(inputSchema);
|
||||
}
|
||||
}
|
||||
|
||||
if(outputSchema != null) {
|
||||
this.outputSchema = outputSchema;
|
||||
if(outputs == null || outputs.isEmpty()) {
|
||||
this.outputs = schemaToPythonVariables(outputSchema);
|
||||
}
|
||||
}
|
||||
}catch(Exception e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void setInputSchema(Schema inputSchema) {
|
||||
Preconditions.checkNotNull(inputSchema,"No input schema found!");
|
||||
this.inputSchema = inputSchema;
|
||||
try{
|
||||
pyInputs = schemaToPythonVariables(inputSchema);
|
||||
inputs = schemaToPythonVariables(inputSchema);
|
||||
}catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
if (outputSchema == null){
|
||||
if (outputSchema == null && outputDict == null){
|
||||
outputSchema = inputSchema;
|
||||
}
|
||||
|
||||
|
@ -88,12 +157,42 @@ public class PythonTransform implements Transform{
|
|||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<Writable> map(List<Writable> writables) {
|
||||
PythonVariables pyInputs = getPyInputsFromWritables(writables);
|
||||
Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");
|
||||
|
||||
|
||||
try{
|
||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
return getWritablesFromPyOutputs(pyOutputs);
|
||||
if (returnAllVariables) {
|
||||
if (setupAndRun){
|
||||
return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
|
||||
}
|
||||
return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
|
||||
}
|
||||
|
||||
if (outputDict != null) {
|
||||
if (setupAndRun) {
|
||||
PythonExecutioner.execWithSetupAndRun(this, pyInputs);
|
||||
}else{
|
||||
PythonExecutioner.exec(this, pyInputs);
|
||||
}
|
||||
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
|
||||
return getWritablesFromPyOutputs(out);
|
||||
}
|
||||
else {
|
||||
if (setupAndRun) {
|
||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
|
||||
}else{
|
||||
PythonExecutioner.exec(code, pyInputs, outputs);
|
||||
}
|
||||
|
||||
return getWritablesFromPyOutputs(outputs);
|
||||
}
|
||||
|
||||
}
|
||||
catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
|
@ -102,7 +201,7 @@ public class PythonTransform implements Transform{
|
|||
|
||||
@Override
|
||||
public String[] outputColumnNames(){
|
||||
return pyOutputs.getVariables();
|
||||
return outputs.getVariables();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -111,7 +210,7 @@ public class PythonTransform implements Transform{
|
|||
}
|
||||
@Override
|
||||
public String[] columnNames(){
|
||||
return pyOutputs.getVariables();
|
||||
return outputs.getVariables();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -125,13 +224,12 @@ public class PythonTransform implements Transform{
|
|||
|
||||
|
||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
||||
|
||||
PythonVariables ret = new PythonVariables();
|
||||
|
||||
for (String name: pyInputs.getVariables()){
|
||||
for (String name: inputs.getVariables()) {
|
||||
int colIdx = inputSchema.getIndexOfColumn(name);
|
||||
Writable w = writables.get(colIdx);
|
||||
PythonVariables.Type pyType = pyInputs.getType(name);
|
||||
PythonVariables.Type pyType = inputs.getType(name);
|
||||
switch (pyType){
|
||||
case INT:
|
||||
if (w instanceof LongWritable){
|
||||
|
@ -151,11 +249,13 @@ public class PythonTransform implements Transform{
|
|||
}
|
||||
break;
|
||||
case STR:
|
||||
ret.addStr(name, ((Text)w).toString());
|
||||
ret.addStr(name, w.toString());
|
||||
break;
|
||||
case NDARRAY:
|
||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Unsupported input type:" + pyType);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -164,83 +264,84 @@ public class PythonTransform implements Transform{
|
|||
|
||||
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
|
||||
List<Writable> out = new ArrayList<>();
|
||||
for (int i=0; i<outputSchema.numColumns(); i++){
|
||||
String name = outputSchema.getName(i);
|
||||
PythonVariables.Type pyType = pyOutputs.getType(name);
|
||||
String[] varNames;
|
||||
varNames = pyOuts.getVariables();
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
for (int i = 0; i < varNames.length; i++) {
|
||||
String name = varNames[i];
|
||||
PythonVariables.Type pyType = pyOuts.getType(name);
|
||||
switch (pyType){
|
||||
case INT:
|
||||
out.add((Writable) new LongWritable(pyOuts.getIntValue(name)));
|
||||
schemaBuilder.addColumnLong(name);
|
||||
break;
|
||||
case FLOAT:
|
||||
out.add((Writable) new DoubleWritable(pyOuts.getFloatValue(name)));
|
||||
schemaBuilder.addColumnDouble(name);
|
||||
break;
|
||||
case STR:
|
||||
out.add((Writable) new Text(pyOuts.getStrValue(name)));
|
||||
case DICT:
|
||||
case LIST:
|
||||
schemaBuilder.addColumnString(name);
|
||||
break;
|
||||
case NDARRAY:
|
||||
out.add((Writable) new NDArrayWritable(pyOuts.getNDArrayValue(name).getNd4jArray()));
|
||||
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
||||
schemaBuilder.addColumnNDArray(name, arr.getShape());
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Unable to support type " + pyType.name());
|
||||
}
|
||||
}
|
||||
this.outputSchema = schemaBuilder.build();
|
||||
|
||||
|
||||
for (int i = 0; i < varNames.length; i++) {
|
||||
String name = varNames[i];
|
||||
PythonVariables.Type pyType = pyOuts.getType(name);
|
||||
|
||||
switch (pyType){
|
||||
case INT:
|
||||
out.add(new LongWritable(pyOuts.getIntValue(name)));
|
||||
break;
|
||||
case FLOAT:
|
||||
out.add(new DoubleWritable(pyOuts.getFloatValue(name)));
|
||||
break;
|
||||
case STR:
|
||||
out.add(new Text(pyOuts.getStrValue(name)));
|
||||
break;
|
||||
case NDARRAY:
|
||||
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
||||
out.add(new NDArrayWritable(arr.getNd4jArray()));
|
||||
break;
|
||||
case DICT:
|
||||
Map<?, ?> dictValue = pyOuts.getDictValue(name);
|
||||
Map noNullValues = new java.util.HashMap<>();
|
||||
for(Map.Entry entry : dictValue.entrySet()) {
|
||||
if(entry.getValue() != org.json.JSONObject.NULL) {
|
||||
noNullValues.put(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
|
||||
}
|
||||
break;
|
||||
case LIST:
|
||||
Object[] listValue = pyOuts.getListValue(name);
|
||||
try {
|
||||
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("Unable to support type " + pyType.name());
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
public PythonTransform(String code) throws Exception{
|
||||
this.code = code;
|
||||
this.name = UUID.randomUUID().toString();
|
||||
}
|
||||
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
|
||||
PythonVariables pyVars = new PythonVariables();
|
||||
int numCols = schema.numColumns();
|
||||
for (int i=0; i<numCols; i++){
|
||||
String colName = schema.getName(i);
|
||||
ColumnType colType = schema.getType(i);
|
||||
switch (colType){
|
||||
case Long:
|
||||
case Integer:
|
||||
pyVars.addInt(colName);
|
||||
break;
|
||||
case Double:
|
||||
case Float:
|
||||
pyVars.addFloat(colName);
|
||||
break;
|
||||
case String:
|
||||
pyVars.addStr(colName);
|
||||
break;
|
||||
case NDArray:
|
||||
pyVars.addNDArray(colName);
|
||||
break;
|
||||
default:
|
||||
throw new Exception("Unsupported python input type: " + colType.toString());
|
||||
}
|
||||
}
|
||||
return pyVars;
|
||||
}
|
||||
|
||||
public PythonTransform(String code, Schema outputSchema) throws Exception{
|
||||
this.code = code;
|
||||
this.name = UUID.randomUUID().toString();
|
||||
this.outputSchema = outputSchema;
|
||||
this.pyOutputs = schemaToPythonVariables(outputSchema);
|
||||
|
||||
|
||||
}
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public String getCode(){
|
||||
return code;
|
||||
}
|
||||
|
||||
public PythonVariables getInputs() {
|
||||
return pyInputs;
|
||||
}
|
||||
|
||||
public PythonVariables getOutputs() {
|
||||
return pyOutputs;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,306 @@
|
|||
package org.datavec.python;
|
||||
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.datavec.api.transform.metadata.BooleanMetaData;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONObject;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* List of utilities for executing python transforms.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class PythonUtils {
|
||||
|
||||
/**
|
||||
* Create a {@link Schema}
|
||||
* from {@link PythonVariables}.
|
||||
* Types are mapped to types of the same name.
|
||||
* @param input the input {@link PythonVariables}
|
||||
* @return the output {@link Schema}
|
||||
*/
|
||||
public static Schema fromPythonVariables(PythonVariables input) {
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none.");
|
||||
for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) {
|
||||
switch(entry.getValue()) {
|
||||
case INT:
|
||||
schemaBuilder.addColumnInteger(entry.getKey());
|
||||
break;
|
||||
case STR:
|
||||
schemaBuilder.addColumnString(entry.getKey());
|
||||
break;
|
||||
case FLOAT:
|
||||
schemaBuilder.addColumnFloat(entry.getKey());
|
||||
break;
|
||||
case NDARRAY:
|
||||
schemaBuilder.addColumnNDArray(entry.getKey(),null);
|
||||
break;
|
||||
case BOOL:
|
||||
schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
|
||||
}
|
||||
}
|
||||
|
||||
return schemaBuilder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link Schema} from an input
|
||||
* {@link PythonVariables}
|
||||
* Types are mapped to types of the same name
|
||||
* @param input the input schema
|
||||
* @return the output python variables.
|
||||
*/
|
||||
public static PythonVariables fromSchema(Schema input) {
|
||||
PythonVariables ret = new PythonVariables();
|
||||
for(int i = 0; i < input.numColumns(); i++) {
|
||||
String currColumnName = input.getName(i);
|
||||
ColumnType columnType = input.getType(i);
|
||||
switch(columnType) {
|
||||
case NDArray:
|
||||
ret.add(currColumnName, PythonVariables.Type.NDARRAY);
|
||||
break;
|
||||
case Boolean:
|
||||
ret.add(currColumnName, PythonVariables.Type.BOOL);
|
||||
break;
|
||||
case Categorical:
|
||||
case String:
|
||||
ret.add(currColumnName, PythonVariables.Type.STR);
|
||||
break;
|
||||
case Double:
|
||||
case Float:
|
||||
ret.add(currColumnName, PythonVariables.Type.FLOAT);
|
||||
break;
|
||||
case Integer:
|
||||
case Long:
|
||||
ret.add(currColumnName, PythonVariables.Type.INT);
|
||||
break;
|
||||
case Bytes:
|
||||
break;
|
||||
case Time:
|
||||
throw new UnsupportedOperationException("Unable to process dates with python yet.");
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
/**
|
||||
* Convert a {@link Schema}
|
||||
* to {@link PythonVariables}
|
||||
* @param schema the input schema
|
||||
* @return the output {@link PythonVariables} where each
|
||||
* name in the map is associated with a column name in the schema.
|
||||
* A proper type is also chosen based on the schema
|
||||
* @throws Exception
|
||||
*/
|
||||
public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception {
|
||||
PythonVariables pyVars = new PythonVariables();
|
||||
int numCols = schema.numColumns();
|
||||
for (int i = 0; i < numCols; i++) {
|
||||
String colName = schema.getName(i);
|
||||
ColumnType colType = schema.getType(i);
|
||||
switch (colType){
|
||||
case Long:
|
||||
case Integer:
|
||||
pyVars.addInt(colName);
|
||||
break;
|
||||
case Double:
|
||||
case Float:
|
||||
pyVars.addFloat(colName);
|
||||
break;
|
||||
case String:
|
||||
pyVars.addStr(colName);
|
||||
break;
|
||||
case NDArray:
|
||||
pyVars.addNDArray(colName);
|
||||
break;
|
||||
default:
|
||||
throw new Exception("Unsupported python input type: " + colType.toString());
|
||||
}
|
||||
}
|
||||
|
||||
return pyVars;
|
||||
}
|
||||
|
||||
|
||||
public static NumpyArray mapToNumpyArray(Map map){
|
||||
String dtypeName = (String)map.get("dtype");
|
||||
DataType dtype;
|
||||
if (dtypeName.equals("float64")){
|
||||
dtype = DataType.DOUBLE;
|
||||
}
|
||||
else if (dtypeName.equals("float32")){
|
||||
dtype = DataType.FLOAT;
|
||||
}
|
||||
else if (dtypeName.equals("int16")){
|
||||
dtype = DataType.SHORT;
|
||||
}
|
||||
else if (dtypeName.equals("int32")){
|
||||
dtype = DataType.INT;
|
||||
}
|
||||
else if (dtypeName.equals("int64")){
|
||||
dtype = DataType.LONG;
|
||||
}
|
||||
else{
|
||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||
}
|
||||
List shapeList = (List)map.get("shape");
|
||||
long[] shape = new long[shapeList.size()];
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
shape[i] = (Long)shapeList.get(i);
|
||||
}
|
||||
|
||||
List strideList = (List)map.get("shape");
|
||||
long[] stride = new long[strideList.size()];
|
||||
for (int i = 0; i < stride.length; i++) {
|
||||
stride[i] = (Long)strideList.get(i);
|
||||
}
|
||||
long address = (Long)map.get("address");
|
||||
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
|
||||
return numpyArray;
|
||||
}
|
||||
|
||||
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){
|
||||
Map dict = pyvars.getDictValue(key);
|
||||
String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]);
|
||||
PythonVariables pyvars2 = new PythonVariables();
|
||||
for (String subkey: keys){
|
||||
Object value = dict.get(subkey);
|
||||
if (value instanceof Map){
|
||||
Map map = (Map)value;
|
||||
if (map.containsKey("_is_numpy_array")){
|
||||
pyvars2.addNDArray(subkey, mapToNumpyArray(map));
|
||||
|
||||
}
|
||||
else{
|
||||
pyvars2.addDict(subkey, (Map)value);
|
||||
}
|
||||
|
||||
}
|
||||
else if (value instanceof List){
|
||||
pyvars2.addList(subkey, ((List) value).toArray());
|
||||
}
|
||||
else if (value instanceof String){
|
||||
System.out.println((String)value);
|
||||
pyvars2.addStr(subkey, (String) value);
|
||||
}
|
||||
else if (value instanceof Integer || value instanceof Long) {
|
||||
Number number = (Number) value;
|
||||
pyvars2.addInt(subkey, number.intValue());
|
||||
}
|
||||
else if (value instanceof Float || value instanceof Double) {
|
||||
Number number = (Number) value;
|
||||
pyvars2.addFloat(subkey, number.doubleValue());
|
||||
}
|
||||
else if (value instanceof NumpyArray){
|
||||
pyvars2.addNDArray(subkey, (NumpyArray)value);
|
||||
}
|
||||
else if (value == null){
|
||||
pyvars2.addStr(subkey, "None"); // FixMe
|
||||
}
|
||||
else{
|
||||
throw new RuntimeException("Unsupported type!" + value);
|
||||
}
|
||||
}
|
||||
return pyvars2;
|
||||
}
|
||||
|
||||
public static long[] jsonArrayToLongArray(JSONArray jsonArray){
|
||||
long[] longs = new long[jsonArray.length()];
|
||||
for (int i=0; i<longs.length; i++){
|
||||
|
||||
longs[i] = jsonArray.getLong(i);
|
||||
}
|
||||
return longs;
|
||||
}
|
||||
|
||||
public static Map<String, Object> toMap(JSONObject jsonobj) {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
|
||||
for (String key: keys){
|
||||
Object value = jsonobj.get(key);
|
||||
if (value instanceof JSONArray) {
|
||||
value = toList((JSONArray) value);
|
||||
} else if (value instanceof JSONObject) {
|
||||
JSONObject jsonobj2 = (JSONObject)value;
|
||||
if (jsonobj2.has("_is_numpy_array")){
|
||||
value = jsonToNumpyArray(jsonobj2);
|
||||
}
|
||||
else{
|
||||
value = toMap(jsonobj2);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
map.put(key, value);
|
||||
} return map;
|
||||
}
|
||||
|
||||
|
||||
public static List<Object> toList(JSONArray array) {
|
||||
List<Object> list = new ArrayList<>();
|
||||
for (int i = 0; i < array.length(); i++) {
|
||||
Object value = array.get(i);
|
||||
if (value instanceof JSONArray) {
|
||||
value = toList((JSONArray) value);
|
||||
} else if (value instanceof JSONObject) {
|
||||
JSONObject jsonobj2 = (JSONObject) value;
|
||||
if (jsonobj2.has("_is_numpy_array")) {
|
||||
value = jsonToNumpyArray(jsonobj2);
|
||||
} else {
|
||||
value = toMap(jsonobj2);
|
||||
}
|
||||
}
|
||||
list.add(value);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
|
||||
private static NumpyArray jsonToNumpyArray(JSONObject map){
|
||||
String dtypeName = (String)map.get("dtype");
|
||||
DataType dtype;
|
||||
if (dtypeName.equals("float64")){
|
||||
dtype = DataType.DOUBLE;
|
||||
}
|
||||
else if (dtypeName.equals("float32")){
|
||||
dtype = DataType.FLOAT;
|
||||
}
|
||||
else if (dtypeName.equals("int16")){
|
||||
dtype = DataType.SHORT;
|
||||
}
|
||||
else if (dtypeName.equals("int32")){
|
||||
dtype = DataType.INT;
|
||||
}
|
||||
else if (dtypeName.equals("int64")){
|
||||
dtype = DataType.LONG;
|
||||
}
|
||||
else{
|
||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||
}
|
||||
List shapeList = (List)map.get("shape");
|
||||
long[] shape = new long[shapeList.size()];
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
shape[i] = (Long)shapeList.get(i);
|
||||
}
|
||||
|
||||
List strideList = (List)map.get("shape");
|
||||
long[] stride = new long[strideList.size()];
|
||||
for (int i = 0; i < stride.length; i++) {
|
||||
stride[i] = (Long)strideList.get(i);
|
||||
}
|
||||
long address = (Long)map.get("address");
|
||||
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
|
||||
return numpyArray;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -17,8 +17,8 @@
|
|||
package org.datavec.python;
|
||||
|
||||
import lombok.Data;
|
||||
import org.json.simple.JSONArray;
|
||||
import org.json.simple.JSONObject;
|
||||
import org.json.JSONObject;
|
||||
import org.json.JSONArray;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
@ -31,8 +31,8 @@ import java.util.*;
|
|||
* @author Fariz Rahman
|
||||
*/
|
||||
|
||||
@Data
|
||||
public class PythonVariables implements Serializable{
|
||||
@lombok.Data
|
||||
public class PythonVariables implements java.io.Serializable {
|
||||
|
||||
public enum Type{
|
||||
BOOL,
|
||||
|
@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{
|
|||
FLOAT,
|
||||
NDARRAY,
|
||||
LIST,
|
||||
FILE
|
||||
FILE,
|
||||
DICT
|
||||
|
||||
}
|
||||
|
||||
private Map<String, String> strVars = new HashMap<String, String>();
|
||||
private Map<String, Long> intVars = new HashMap<String, Long>();
|
||||
private Map<String, Double> floatVars = new HashMap<String, Double>();
|
||||
private Map<String, Boolean> boolVars = new HashMap<String, Boolean>();
|
||||
private Map<String, NumpyArray> ndVars = new HashMap<String, NumpyArray>();
|
||||
private Map<String, Object[]> listVars = new HashMap<String, Object[]>();
|
||||
private Map<String, String> fileVars = new HashMap<String, String>();
|
||||
|
||||
private Map<String, Type> vars = new HashMap<String, Type>();
|
||||
|
||||
private Map<Type, Map> maps = new HashMap<Type, Map>();
|
||||
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
|
||||
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();
|
||||
|
||||
|
||||
/**
|
||||
* Returns a copy of the variable
|
||||
* schema in this array without the values
|
||||
* @return an empty variables clone
|
||||
* with no values
|
||||
*/
|
||||
public PythonVariables copySchema(){
|
||||
PythonVariables ret = new PythonVariables();
|
||||
for (String varName: getVariables()){
|
||||
|
@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
public PythonVariables(){
|
||||
maps.put(Type.BOOL, boolVars);
|
||||
maps.put(Type.STR, strVars);
|
||||
maps.put(Type.INT, intVars);
|
||||
maps.put(Type.FLOAT, floatVars);
|
||||
maps.put(Type.NDARRAY, ndVars);
|
||||
maps.put(Type.LIST, listVars);
|
||||
maps.put(Type.FILE, fileVars);
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public PythonVariables() {
|
||||
maps.put(PythonVariables.Type.BOOL, boolVariables);
|
||||
maps.put(PythonVariables.Type.STR, strVariables);
|
||||
maps.put(PythonVariables.Type.INT, intVariables);
|
||||
maps.put(PythonVariables.Type.FLOAT, floatVariables);
|
||||
maps.put(PythonVariables.Type.NDARRAY, ndVars);
|
||||
maps.put(PythonVariables.Type.LIST, listVariables);
|
||||
maps.put(PythonVariables.Type.FILE, fileVariables);
|
||||
maps.put(PythonVariables.Type.DICT, dictVariables);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @return true if there are no variables.
|
||||
*/
|
||||
public boolean isEmpty() {
|
||||
return getVariables().length < 1;
|
||||
}
|
||||
|
||||
|
||||
|
@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{
|
|||
break;
|
||||
case FILE:
|
||||
addFile(name);
|
||||
break;
|
||||
case DICT:
|
||||
addDict(name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,225 +137,433 @@ public class PythonVariables implements Serializable{
|
|||
* @param name name of the variable
|
||||
* @param type type of the variable
|
||||
* @param value value of the variable (must be instance of expected type)
|
||||
* @throws Exception
|
||||
*/
|
||||
public void add (String name, Type type, Object value) throws Exception{
|
||||
public void add(String name, Type type, Object value) {
|
||||
add(name, type);
|
||||
setValue(name, value);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addDict(String name) {
|
||||
vars.put(name, PythonVariables.Type.DICT);
|
||||
dictVariables.put(name,null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addBool(String name){
|
||||
vars.put(name, Type.BOOL);
|
||||
boolVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.BOOL);
|
||||
boolVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addStr(String name){
|
||||
vars.put(name, Type.STR);
|
||||
strVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.STR);
|
||||
strVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addInt(String name){
|
||||
vars.put(name, Type.INT);
|
||||
intVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.INT);
|
||||
intVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addFloat(String name){
|
||||
vars.put(name, Type.FLOAT);
|
||||
floatVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.FLOAT);
|
||||
floatVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addNDArray(String name){
|
||||
vars.put(name, Type.NDARRAY);
|
||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
||||
ndVars.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addList(String name){
|
||||
vars.put(name, Type.LIST);
|
||||
listVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.LIST);
|
||||
listVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
*/
|
||||
public void addFile(String name){
|
||||
vars.put(name, Type.FILE);
|
||||
fileVars.put(name, null);
|
||||
vars.put(name, PythonVariables.Type.FILE);
|
||||
fileVariables.put(name, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a boolean variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addBool(String name, boolean value) {
|
||||
vars.put(name, Type.BOOL);
|
||||
boolVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.BOOL);
|
||||
boolVariables.put(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a string variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addStr(String name, String value) {
|
||||
vars.put(name, Type.STR);
|
||||
strVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.STR);
|
||||
strVariables.put(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an int variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addInt(String name, int value) {
|
||||
vars.put(name, Type.INT);
|
||||
intVars.put(name, (long)value);
|
||||
vars.put(name, PythonVariables.Type.INT);
|
||||
intVariables.put(name, (long)value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a long variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addInt(String name, long value) {
|
||||
vars.put(name, Type.INT);
|
||||
intVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.INT);
|
||||
intVariables.put(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a double variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addFloat(String name, double value) {
|
||||
vars.put(name, Type.FLOAT);
|
||||
floatVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.FLOAT);
|
||||
floatVariables.put(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a float variable to
|
||||
* the set of variables
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addFloat(String name, float value) {
|
||||
vars.put(name, Type.FLOAT);
|
||||
floatVars.put(name, (double)value);
|
||||
vars.put(name, PythonVariables.Type.FLOAT);
|
||||
floatVariables.put(name, (double)value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addNDArray(String name, NumpyArray value) {
|
||||
vars.put(name, Type.NDARRAY);
|
||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
||||
ndVars.put(name, value);
|
||||
}
|
||||
|
||||
public void addNDArray(String name, INDArray value){
|
||||
vars.put(name, Type.NDARRAY);
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
||||
ndVars.put(name, new NumpyArray(value));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addList(String name, Object[] value) {
|
||||
vars.put(name, Type.LIST);
|
||||
listVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.LIST);
|
||||
listVariables.put(name, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addFile(String name, String value) {
|
||||
vars.put(name, Type.FILE);
|
||||
fileVars.put(name, value);
|
||||
vars.put(name, PythonVariables.Type.FILE);
|
||||
fileVariables.put(name, value);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Add a null variable to
|
||||
* the set of variables
|
||||
* to describe the type but no value
|
||||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addDict(String name, java.util.Map value) {
|
||||
vars.put(name, PythonVariables.Type.DICT);
|
||||
dictVariables.put(name, value);
|
||||
}
|
||||
/**
|
||||
*
|
||||
* @param name name of the variable
|
||||
* @param value new value for the variable
|
||||
* @throws Exception
|
||||
*/
|
||||
public void setValue(String name, Object value) {
|
||||
Type type = vars.get(name);
|
||||
if (type == Type.BOOL){
|
||||
boolVars.put(name, (Boolean)value);
|
||||
if (type == PythonVariables.Type.BOOL){
|
||||
boolVariables.put(name, (Boolean)value);
|
||||
}
|
||||
else if (type == Type.INT){
|
||||
if (value instanceof Long){
|
||||
intVars.put(name, ((Long)value));
|
||||
else if (type == PythonVariables.Type.INT){
|
||||
Number number = (Number) value;
|
||||
intVariables.put(name, number.longValue());
|
||||
}
|
||||
else if (value instanceof Integer){
|
||||
intVars.put(name, ((Integer)value).longValue());
|
||||
|
||||
else if (type == PythonVariables.Type.FLOAT){
|
||||
Number number = (Number) value;
|
||||
floatVariables.put(name, number.doubleValue());
|
||||
}
|
||||
}
|
||||
else if (type == Type.FLOAT){
|
||||
floatVars.put(name, (Double)value);
|
||||
}
|
||||
else if (type == Type.NDARRAY){
|
||||
else if (type == PythonVariables.Type.NDARRAY){
|
||||
if (value instanceof NumpyArray){
|
||||
ndVars.put(name, (NumpyArray)value);
|
||||
}
|
||||
else if (value instanceof INDArray){
|
||||
ndVars.put(name, new NumpyArray((INDArray) value));
|
||||
else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
|
||||
ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
|
||||
}
|
||||
else{
|
||||
throw new RuntimeException("Unsupported type: " + value.getClass().toString());
|
||||
}
|
||||
}
|
||||
else if (type == Type.LIST){
|
||||
listVars.put(name, (Object[]) value);
|
||||
else if (type == PythonVariables.Type.LIST) {
|
||||
if (value instanceof java.util.List) {
|
||||
value = ((java.util.List) value).toArray();
|
||||
listVariables.put(name, (Object[]) value);
|
||||
}
|
||||
else if (type == Type.FILE){
|
||||
fileVars.put(name, (String)value);
|
||||
else if(value instanceof org.json.JSONArray) {
|
||||
org.json.JSONArray jsonArray = (org.json.JSONArray) value;
|
||||
Object[] copyArr = new Object[jsonArray.length()];
|
||||
for(int i = 0; i < copyArr.length; i++) {
|
||||
copyArr[i] = jsonArray.get(i);
|
||||
}
|
||||
listVariables.put(name, copyArr);
|
||||
|
||||
}
|
||||
else {
|
||||
strVars.put(name, (String)value);
|
||||
listVariables.put(name, (Object[]) value);
|
||||
}
|
||||
}
|
||||
else if(type == PythonVariables.Type.DICT) {
|
||||
dictVariables.put(name,(java.util.Map<?,?>) value);
|
||||
}
|
||||
else if (type == PythonVariables.Type.FILE){
|
||||
fileVariables.put(name, (String)value);
|
||||
}
|
||||
else{
|
||||
strVariables.put(name, (String)value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Do a general object lookup.
|
||||
* The look up will happen relative to the {@link Type}
|
||||
* of variable is described in the
|
||||
* @param name the name of the variable to get
|
||||
* @return teh value for the variable with the given name
|
||||
*/
|
||||
public Object getValue(String name) {
|
||||
Type type = vars.get(name);
|
||||
Map map = maps.get(type);
|
||||
java.util.Map map = maps.get(type);
|
||||
return map.get(name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns a boolean variable with the given name.
|
||||
* @param name the variable name to get the value for
|
||||
* @return the retrieved boolean value
|
||||
*/
|
||||
public boolean getBooleanValue(String name) {
|
||||
return boolVariables.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the dictionary value
|
||||
*/
|
||||
public java.util.Map<?,?> getDictValue(String name) {
|
||||
return dictVariables.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the string value
|
||||
*/
|
||||
public String getStrValue(String name){
|
||||
return strVars.get(name);
|
||||
return strVariables.get(name);
|
||||
}
|
||||
|
||||
public long getIntValue(String name){
|
||||
return intVars.get(name);
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the long value
|
||||
*/
|
||||
public Long getIntValue(String name){
|
||||
return intVariables.get(name);
|
||||
}
|
||||
|
||||
public double getFloatValue(String name){
|
||||
return floatVars.get(name);
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the float value
|
||||
*/
|
||||
public Double getFloatValue(String name){
|
||||
return floatVariables.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the numpy array value
|
||||
*/
|
||||
public NumpyArray getNDArrayValue(String name){
|
||||
return ndVars.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the list value as an object array
|
||||
*/
|
||||
public Object[] getListValue(String name){
|
||||
return listVars.get(name);
|
||||
return listVariables.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name the variable name
|
||||
* @return the value of the given file name
|
||||
*/
|
||||
public String getFileValue(String name){
|
||||
return fileVars.get(name);
|
||||
return fileVariables.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the type for the given variable name
|
||||
* @param name the name of the variable to get the type for
|
||||
* @return the type for the given variable
|
||||
*/
|
||||
public Type getType(String name){
|
||||
return vars.get(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all the variables present as a string array
|
||||
* @return the variable names for this variable sset
|
||||
*/
|
||||
public String[] getVariables() {
|
||||
String[] strArr = new String[vars.size()];
|
||||
return vars.keySet().toArray(strArr);
|
||||
}
|
||||
|
||||
|
||||
public Map<String, Boolean> getBoolVariables(){
|
||||
return boolVars;
|
||||
}
|
||||
public Map<String, String> getStrVariables(){
|
||||
return strVars;
|
||||
}
|
||||
|
||||
public Map<String, Long> getIntVariables(){
|
||||
return intVars;
|
||||
}
|
||||
|
||||
public Map<String, Double> getFloatVariables(){
|
||||
return floatVars;
|
||||
}
|
||||
|
||||
public Map<String, NumpyArray> getNDArrayVariables(){
|
||||
return ndVars;
|
||||
}
|
||||
|
||||
public Map<String, Object[]> getListVariables(){
|
||||
return listVars;
|
||||
}
|
||||
|
||||
public Map<String, String> getFileVariables(){
|
||||
return fileVars;
|
||||
}
|
||||
|
||||
public JSONArray toJSON(){
|
||||
JSONArray arr = new JSONArray();
|
||||
/**
|
||||
* This variables set as its json representation (an array of json objects)
|
||||
* @return the json array output
|
||||
*/
|
||||
public org.json.JSONArray toJSON(){
|
||||
org.json.JSONArray arr = new org.json.JSONArray();
|
||||
for (String varName: getVariables()){
|
||||
JSONObject var = new JSONObject();
|
||||
org.json.JSONObject var = new org.json.JSONObject();
|
||||
var.put("name", varName);
|
||||
String varType = getType(varName).toString();
|
||||
var.put("type", varType);
|
||||
arr.add(var);
|
||||
arr.put(var);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
public static PythonVariables fromJSON(JSONArray jsonArray){
|
||||
/**
|
||||
* Create a schema from a map.
|
||||
* This is an empty PythonVariables
|
||||
* that just contains names and types with no values
|
||||
* @param inputTypes the input types to convert
|
||||
* @return the schema from the given map
|
||||
*/
|
||||
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
|
||||
PythonVariables ret = new PythonVariables();
|
||||
for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) {
|
||||
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the python variable state relative to the
|
||||
* input json array
|
||||
* @param jsonArray the input json array
|
||||
* @return the python variables based on the input json array
|
||||
*/
|
||||
public static PythonVariables fromJSON(org.json.JSONArray jsonArray){
|
||||
PythonVariables pyvars = new PythonVariables();
|
||||
for (int i=0; i<jsonArray.size(); i++){
|
||||
JSONObject input = (JSONObject) jsonArray.get(i);
|
||||
for (int i = 0; i < jsonArray.length(); i++) {
|
||||
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
|
||||
String varName = (String)input.get("name");
|
||||
String varType = (String)input.get("type");
|
||||
if (varType.equals("BOOL")) {
|
||||
|
@ -355,6 +587,9 @@ public class PythonVariables implements Serializable{
|
|||
else if (varType.equals("NDARRAY")) {
|
||||
pyvars.addNDArray(varName);
|
||||
}
|
||||
else if(varType.equals("DICT")) {
|
||||
pyvars.addDict(varName);
|
||||
}
|
||||
}
|
||||
|
||||
return pyvars;
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
|
||||
import sys
|
||||
this = sys.modules[__name__]
|
||||
for n in dir():
|
||||
if n[0]!='_': delattr(this, n)
|
|
@ -0,0 +1 @@
|
|||
loc = {}
|
|
@ -0,0 +1,20 @@
|
|||
|
||||
def __is_numpy_array(x):
|
||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
||||
|
||||
def maybe_serialize_ndarray_metadata(x):
|
||||
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
||||
|
||||
|
||||
def serialize_ndarray_metadata(x):
|
||||
return {"address": x.__array_interface__['data'][0],
|
||||
"shape": x.shape,
|
||||
"strides": x.strides,
|
||||
"dtype": str(x.dtype),
|
||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
||||
|
||||
|
||||
def is_json_ready(key, value):
|
||||
return key is not 'f2' and not inspect.ismodule(value) \
|
||||
and not hasattr(value, '__call__')
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
#patch
|
||||
|
||||
"""Implementation of __array_function__ overrides from NEP-18."""
|
||||
import collections
|
||||
import functools
|
||||
import os
|
||||
|
||||
from numpy.core._multiarray_umath import (
|
||||
add_docstring, implement_array_function, _get_implementing_args)
|
||||
from numpy.compat._inspect import getargspec
|
||||
|
||||
|
||||
ENABLE_ARRAY_FUNCTION = bool(
|
||||
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
|
||||
|
||||
|
||||
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
|
||||
|
||||
|
||||
_add_docstring = add_docstring
|
||||
|
||||
|
||||
def add_docstring(*args):
|
||||
try:
|
||||
_add_docstring(*args)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
add_docstring(
|
||||
implement_array_function,
|
||||
"""
|
||||
Implement a function with checks for __array_function__ overrides.
|
||||
|
||||
All arguments are required, and can only be passed by position.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
implementation : function
|
||||
Function that implements the operation on NumPy array without
|
||||
overrides when called like ``implementation(*args, **kwargs)``.
|
||||
public_api : function
|
||||
Function exposed by NumPy's public API originally called like
|
||||
``public_api(*args, **kwargs)`` on which arguments are now being
|
||||
checked.
|
||||
relevant_args : iterable
|
||||
Iterable of arguments to check for __array_function__ methods.
|
||||
args : tuple
|
||||
Arbitrary positional arguments originally passed into ``public_api``.
|
||||
kwargs : dict
|
||||
Arbitrary keyword arguments originally passed into ``public_api``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Result from calling ``implementation()`` or an ``__array_function__``
|
||||
method, as appropriate.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError : if no implementation is found.
|
||||
""")
|
||||
|
||||
|
||||
# exposed for testing purposes; used internally by implement_array_function
|
||||
add_docstring(
|
||||
_get_implementing_args,
|
||||
"""
|
||||
Collect arguments on which to call __array_function__.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
relevant_args : iterable of array-like
|
||||
Iterable of possibly array-like arguments to check for
|
||||
__array_function__ methods.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Sequence of arguments with __array_function__ methods, in the order in
|
||||
which they should be called.
|
||||
""")
|
||||
|
||||
|
||||
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
|
||||
|
||||
|
||||
def verify_matching_signatures(implementation, dispatcher):
|
||||
"""Verify that a dispatcher function has the right signature."""
|
||||
implementation_spec = ArgSpec(*getargspec(implementation))
|
||||
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
|
||||
|
||||
if (implementation_spec.args != dispatcher_spec.args or
|
||||
implementation_spec.varargs != dispatcher_spec.varargs or
|
||||
implementation_spec.keywords != dispatcher_spec.keywords or
|
||||
(bool(implementation_spec.defaults) !=
|
||||
bool(dispatcher_spec.defaults)) or
|
||||
(implementation_spec.defaults is not None and
|
||||
len(implementation_spec.defaults) !=
|
||||
len(dispatcher_spec.defaults))):
|
||||
raise RuntimeError('implementation and dispatcher for %s have '
|
||||
'different function signatures' % implementation)
|
||||
|
||||
if implementation_spec.defaults is not None:
|
||||
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
|
||||
raise RuntimeError('dispatcher functions can only use None for '
|
||||
'default argument values')
|
||||
|
||||
|
||||
def set_module(module):
|
||||
"""Decorator for overriding __module__ on a function or class.
|
||||
|
||||
Example usage::
|
||||
|
||||
@set_module('numpy')
|
||||
def example():
|
||||
pass
|
||||
|
||||
assert example.__module__ == 'numpy'
|
||||
"""
|
||||
def decorator(func):
|
||||
if module is not None:
|
||||
func.__module__ = module
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def array_function_dispatch(dispatcher, module=None, verify=True,
|
||||
docs_from_dispatcher=False):
|
||||
"""Decorator for adding dispatch with the __array_function__ protocol.
|
||||
|
||||
See NEP-18 for example usage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dispatcher : callable
|
||||
Function that when called like ``dispatcher(*args, **kwargs)`` with
|
||||
arguments from the NumPy function call returns an iterable of
|
||||
array-like arguments to check for ``__array_function__``.
|
||||
module : str, optional
|
||||
__module__ attribute to set on new function, e.g., ``module='numpy'``.
|
||||
By default, module is copied from the decorated function.
|
||||
verify : bool, optional
|
||||
If True, verify the that the signature of the dispatcher and decorated
|
||||
function signatures match exactly: all required and optional arguments
|
||||
should appear in order with the same names, but the default values for
|
||||
all optional arguments should be ``None``. Only disable verification
|
||||
if the dispatcher's signature needs to deviate for some particular
|
||||
reason, e.g., because the function has a signature like
|
||||
``func(*args, **kwargs)``.
|
||||
docs_from_dispatcher : bool, optional
|
||||
If True, copy docs from the dispatcher function onto the dispatched
|
||||
function, rather than from the implementation. This is useful for
|
||||
functions defined in C, which otherwise don't have docstrings.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Function suitable for decorating the implementation of a NumPy function.
|
||||
"""
|
||||
|
||||
if not ENABLE_ARRAY_FUNCTION:
|
||||
# __array_function__ requires an explicit opt-in for now
|
||||
def decorator(implementation):
|
||||
if module is not None:
|
||||
implementation.__module__ = module
|
||||
if docs_from_dispatcher:
|
||||
add_docstring(implementation, dispatcher.__doc__)
|
||||
return implementation
|
||||
return decorator
|
||||
|
||||
def decorator(implementation):
|
||||
if verify:
|
||||
verify_matching_signatures(implementation, dispatcher)
|
||||
|
||||
if docs_from_dispatcher:
|
||||
add_docstring(implementation, dispatcher.__doc__)
|
||||
|
||||
@functools.wraps(implementation)
|
||||
def public_api(*args, **kwargs):
|
||||
relevant_args = dispatcher(*args, **kwargs)
|
||||
return implement_array_function(
|
||||
implementation, public_api, relevant_args, args, kwargs)
|
||||
|
||||
if module is not None:
|
||||
public_api.__module__ = module
|
||||
|
||||
# TODO: remove this when we drop Python 2 support (functools.wraps)
|
||||
# adds __wrapped__ automatically in later versions)
|
||||
public_api.__wrapped__ = implementation
|
||||
|
||||
return public_api
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def array_function_from_dispatcher(
|
||||
implementation, module=None, verify=True, docs_from_dispatcher=True):
|
||||
"""Like array_function_dispatcher, but with function arguments flipped."""
|
||||
|
||||
def decorator(dispatcher):
|
||||
return array_function_dispatch(
|
||||
dispatcher, module, verify=verify,
|
||||
docs_from_dispatcher=docs_from_dispatcher)(implementation)
|
||||
return decorator
|
|
@ -0,0 +1,172 @@
|
|||
#patch 1
|
||||
|
||||
"""
|
||||
========================
|
||||
Random Number Generation
|
||||
========================
|
||||
|
||||
==================== =========================================================
|
||||
Utility functions
|
||||
==============================================================================
|
||||
random_sample Uniformly distributed floats over ``[0, 1)``.
|
||||
random Alias for `random_sample`.
|
||||
bytes Uniformly distributed random bytes.
|
||||
random_integers Uniformly distributed integers in a given range.
|
||||
permutation Randomly permute a sequence / generate a random sequence.
|
||||
shuffle Randomly permute a sequence in place.
|
||||
seed Seed the random number generator.
|
||||
choice Random sample from 1-D array.
|
||||
|
||||
==================== =========================================================
|
||||
|
||||
==================== =========================================================
|
||||
Compatibility functions
|
||||
==============================================================================
|
||||
rand Uniformly distributed values.
|
||||
randn Normally distributed values.
|
||||
ranf Uniformly distributed floating point numbers.
|
||||
randint Uniformly distributed integers in a given range.
|
||||
==================== =========================================================
|
||||
|
||||
==================== =========================================================
|
||||
Univariate distributions
|
||||
==============================================================================
|
||||
beta Beta distribution over ``[0, 1]``.
|
||||
binomial Binomial distribution.
|
||||
chisquare :math:`\\chi^2` distribution.
|
||||
exponential Exponential distribution.
|
||||
f F (Fisher-Snedecor) distribution.
|
||||
gamma Gamma distribution.
|
||||
geometric Geometric distribution.
|
||||
gumbel Gumbel distribution.
|
||||
hypergeometric Hypergeometric distribution.
|
||||
laplace Laplace distribution.
|
||||
logistic Logistic distribution.
|
||||
lognormal Log-normal distribution.
|
||||
logseries Logarithmic series distribution.
|
||||
negative_binomial Negative binomial distribution.
|
||||
noncentral_chisquare Non-central chi-square distribution.
|
||||
noncentral_f Non-central F distribution.
|
||||
normal Normal / Gaussian distribution.
|
||||
pareto Pareto distribution.
|
||||
poisson Poisson distribution.
|
||||
power Power distribution.
|
||||
rayleigh Rayleigh distribution.
|
||||
triangular Triangular distribution.
|
||||
uniform Uniform distribution.
|
||||
vonmises Von Mises circular distribution.
|
||||
wald Wald (inverse Gaussian) distribution.
|
||||
weibull Weibull distribution.
|
||||
zipf Zipf's distribution over ranked data.
|
||||
==================== =========================================================
|
||||
|
||||
==================== =========================================================
|
||||
Multivariate distributions
|
||||
==============================================================================
|
||||
dirichlet Multivariate generalization of Beta distribution.
|
||||
multinomial Multivariate generalization of the binomial distribution.
|
||||
multivariate_normal Multivariate generalization of the normal distribution.
|
||||
==================== =========================================================
|
||||
|
||||
==================== =========================================================
|
||||
Standard distributions
|
||||
==============================================================================
|
||||
standard_cauchy Standard Cauchy-Lorentz distribution.
|
||||
standard_exponential Standard exponential distribution.
|
||||
standard_gamma Standard Gamma distribution.
|
||||
standard_normal Standard normal distribution.
|
||||
standard_t Standard Student's t-distribution.
|
||||
==================== =========================================================
|
||||
|
||||
==================== =========================================================
|
||||
Internal functions
|
||||
==============================================================================
|
||||
get_state Get tuple representing internal state of generator.
|
||||
set_state Set state of generator.
|
||||
==================== =========================================================
|
||||
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import warnings
|
||||
|
||||
__all__ = [
|
||||
'beta',
|
||||
'binomial',
|
||||
'bytes',
|
||||
'chisquare',
|
||||
'choice',
|
||||
'dirichlet',
|
||||
'exponential',
|
||||
'f',
|
||||
'gamma',
|
||||
'geometric',
|
||||
'get_state',
|
||||
'gumbel',
|
||||
'hypergeometric',
|
||||
'laplace',
|
||||
'logistic',
|
||||
'lognormal',
|
||||
'logseries',
|
||||
'multinomial',
|
||||
'multivariate_normal',
|
||||
'negative_binomial',
|
||||
'noncentral_chisquare',
|
||||
'noncentral_f',
|
||||
'normal',
|
||||
'pareto',
|
||||
'permutation',
|
||||
'poisson',
|
||||
'power',
|
||||
'rand',
|
||||
'randint',
|
||||
'randn',
|
||||
'random_integers',
|
||||
'random_sample',
|
||||
'rayleigh',
|
||||
'seed',
|
||||
'set_state',
|
||||
'shuffle',
|
||||
'standard_cauchy',
|
||||
'standard_exponential',
|
||||
'standard_gamma',
|
||||
'standard_normal',
|
||||
'standard_t',
|
||||
'triangular',
|
||||
'uniform',
|
||||
'vonmises',
|
||||
'wald',
|
||||
'weibull',
|
||||
'zipf'
|
||||
]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
|
||||
try:
|
||||
from .mtrand import *
|
||||
# Some aliases:
|
||||
ranf = random = sample = random_sample
|
||||
__all__.extend(['ranf', 'random', 'sample'])
|
||||
except:
|
||||
warnings.warn("numpy.random is not available when using multiple interpreters!")
|
||||
|
||||
|
||||
|
||||
def __RandomState_ctor():
|
||||
"""Return a RandomState instance.
|
||||
|
||||
This function exists solely to assist (un)pickling.
|
||||
|
||||
Note that the state of the RandomState returned here is irrelevant, as this function's
|
||||
entire purpose is to return a newly allocated RandomState whose state pickle can set.
|
||||
Consequently the RandomState returned by this function is a freshly allocated copy
|
||||
with a seed=0.
|
||||
|
||||
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
|
||||
|
||||
"""
|
||||
return RandomState(seed=0)
|
||||
|
||||
from numpy._pytesttester import PytestTester
|
||||
test = PytestTester(__name__)
|
||||
del PytestTester
|
|
@ -0,0 +1,20 @@
|
|||
import sys
|
||||
import traceback
|
||||
import json
|
||||
import inspect
|
||||
|
||||
|
||||
try:
|
||||
|
||||
pass
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
except Exception as ex:
|
||||
try:
|
||||
exc_info = sys.exc_info()
|
||||
finally:
|
||||
print(ex)
|
||||
traceback.print_exception(*exc_info)
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
def __is_numpy_array(x):
|
||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
||||
|
||||
def __maybe_serialize_ndarray_metadata(x):
|
||||
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
||||
|
||||
|
||||
def __serialize_ndarray_metadata(x):
|
||||
return {"address": x.__array_interface__['data'][0],
|
||||
"shape": x.shape,
|
||||
"strides": x.strides,
|
||||
"dtype": str(x.dtype),
|
||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
||||
|
||||
|
||||
def __serialize_list(x):
|
||||
import json
|
||||
return json.dumps(__recursive_serialize_list(x))
|
||||
|
||||
|
||||
def __serialize_dict(x):
|
||||
import json
|
||||
return json.dumps(__recursive_serialize_dict(x))
|
||||
|
||||
def __recursive_serialize_list(x):
|
||||
out = []
|
||||
for i in x:
|
||||
if __is_numpy_array(i):
|
||||
out.append(__serialize_ndarray_metadata(i))
|
||||
elif isinstance(i, (list, tuple)):
|
||||
out.append(__recursive_serialize_list(i))
|
||||
elif isinstance(i, dict):
|
||||
out.append(__recursive_serialize_dict(i))
|
||||
else:
|
||||
out.append(i)
|
||||
return out
|
||||
|
||||
def __recursive_serialize_dict(x):
|
||||
out = {}
|
||||
for k in x:
|
||||
v = x[k]
|
||||
if __is_numpy_array(v):
|
||||
out[k] = __serialize_ndarray_metadata(v)
|
||||
elif isinstance(v, (list, tuple)):
|
||||
out[k] = __recursive_serialize_list(v)
|
||||
elif isinstance(v, dict):
|
||||
out[k] = __recursive_serialize_dict(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
|
@ -0,0 +1,75 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.python;
|
||||
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
@javax.annotation.concurrent.NotThreadSafe
|
||||
public class TestPythonExecutionSandbox {
|
||||
|
||||
@Test
|
||||
public void testInt(){
|
||||
PythonExecutioner.setInterpreter("interp1");
|
||||
PythonExecutioner.exec("a = 1");
|
||||
PythonExecutioner.setInterpreter("interp2");
|
||||
PythonExecutioner.exec("a = 2");
|
||||
PythonExecutioner.setInterpreter("interp3");
|
||||
PythonExecutioner.exec("a = 3");
|
||||
|
||||
|
||||
PythonExecutioner.setInterpreter("interp1");
|
||||
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
|
||||
|
||||
PythonExecutioner.setInterpreter("interp2");
|
||||
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
|
||||
|
||||
PythonExecutioner.setInterpreter("interp3");
|
||||
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArray(){
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
PythonExecutioner.exec("import numpy as np");
|
||||
PythonExecutioner.exec("a = np.zeros(5)");
|
||||
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
//PythonExecutioner.exec("import numpy as np");
|
||||
PythonExecutioner.exec("a = np.zeros(5)");
|
||||
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
PythonExecutioner.exec("a += 2");
|
||||
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
PythonExecutioner.exec("a += 3");
|
||||
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
//PythonExecutioner.exec("import numpy as np");
|
||||
// PythonExecutioner.exec("a = np.zeros(5)");
|
||||
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumpyRandom(){
|
||||
PythonExecutioner.setInterpreter("main");
|
||||
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
|
||||
}
|
||||
}
|
|
@ -15,17 +15,25 @@
|
|||
******************************************************************************/
|
||||
|
||||
package org.datavec.python;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
||||
|
||||
@javax.annotation.concurrent.NotThreadSafe
|
||||
public class TestPythonExecutioner {
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
|
||||
@org.junit.Test
|
||||
public void testPythonSysVersion() {
|
||||
PythonExecutioner.exec("import sys; print(sys.version)");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStr() throws Exception{
|
||||
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
|
@ -47,7 +55,7 @@ public class TestPythonExecutioner {
|
|||
assertEquals("Hello World", z);
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testInt()throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -64,11 +72,11 @@ public class TestPythonExecutioner {
|
|||
|
||||
long z = pyOutputs.getIntValue("z");
|
||||
|
||||
assertEquals(30, z);
|
||||
Assert.assertEquals(30, z);
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testList() throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -88,18 +96,35 @@ public class TestPythonExecutioner {
|
|||
|
||||
Object[] z = pyOutputs.getListValue("z");
|
||||
|
||||
assertEquals(z.length, x.length + y.length);
|
||||
Assert.assertEquals(z.length, x.length + y.length);
|
||||
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
assertEquals(x[i], z[i]);
|
||||
if(x[i] instanceof Number) {
|
||||
Number xNum = (Number) x[i];
|
||||
Number zNum = (Number) z[i];
|
||||
Assert.assertEquals(xNum.intValue(), zNum.intValue());
|
||||
}
|
||||
else {
|
||||
Assert.assertEquals(x[i], z[i]);
|
||||
}
|
||||
|
||||
}
|
||||
for (int i = 0; i < y.length; i++){
|
||||
assertEquals(y[i], z[x.length + i]);
|
||||
if(y[i] instanceof Number) {
|
||||
Number yNum = (Number) y[i];
|
||||
Number zNum = (Number) z[x.length + i];
|
||||
Assert.assertEquals(yNum.intValue(), zNum.intValue());
|
||||
}
|
||||
else {
|
||||
Assert.assertEquals(y[i], z[x.length + i]);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testNDArrayFloat()throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -113,12 +138,17 @@ public class TestPythonExecutioner {
|
|||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
||||
|
||||
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testTensorflowCustomAnaconda() {
|
||||
PythonExecutioner.exec("import tensorflow as tf");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayDouble()throws Exception {
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -132,10 +162,10 @@ public class TestPythonExecutioner {
|
|||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
||||
|
||||
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testNDArrayShort()throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -149,11 +179,11 @@ public class TestPythonExecutioner {
|
|||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
||||
|
||||
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
}
|
||||
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testNDArrayInt()throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -167,11 +197,11 @@ public class TestPythonExecutioner {
|
|||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
||||
|
||||
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Test
|
||||
public void testNDArrayLong()throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
|
@ -185,7 +215,7 @@ public class TestPythonExecutioner {
|
|||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
||||
|
||||
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package org.datavec.python;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@javax.annotation.concurrent.NotThreadSafe
|
||||
public class TestPythonSetupAndRun {
|
||||
@Test
|
||||
public void testPythonWithSetupAndRun() throws Exception{
|
||||
String code = "def setup():" +
|
||||
"global counter;counter=0\n" +
|
||||
"def run(step):" +
|
||||
"global counter;" +
|
||||
"counter+=step;" +
|
||||
"return {\"counter\":counter}";
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
pyInputs.addInt("step", 2);
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
pyOutputs.addInt("counter");
|
||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
|
||||
assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
|
||||
pyInputs.addInt("step", 3);
|
||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
|
||||
assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * * Copyright (c) 2015-2019 Skymind Inc.
|
||||
* * * Copyright (c) 2019 Konduit AI.
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.datavec.python;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static junit.framework.TestCase.assertNotNull;
|
||||
import static junit.framework.TestCase.assertNull;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestPythonVariables {
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testImportNumpy(){
|
||||
Nd4j.scalar(1.0);
|
||||
System.out.println(System.getProperty("org.bytedeco.openblas.load"));
|
||||
PythonExecutioner.exec("import numpy as np");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testDataAssociations() {
|
||||
PythonVariables pythonVariables = new PythonVariables();
|
||||
PythonVariables.Type[] types = {
|
||||
PythonVariables.Type.INT,
|
||||
PythonVariables.Type.FLOAT,
|
||||
PythonVariables.Type.STR,
|
||||
PythonVariables.Type.BOOL,
|
||||
PythonVariables.Type.DICT,
|
||||
PythonVariables.Type.LIST,
|
||||
PythonVariables.Type.LIST,
|
||||
PythonVariables.Type.FILE,
|
||||
PythonVariables.Type.NDARRAY
|
||||
};
|
||||
|
||||
NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
|
||||
Object[] values = {
|
||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||
new Object[]{1}, Arrays.asList(1),"type", npArr
|
||||
};
|
||||
|
||||
Object[] expectedValues = {
|
||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||
new Object[]{1}, new Object[]{1},"type", npArr
|
||||
};
|
||||
|
||||
for(int i = 0; i < types.length; i++) {
|
||||
testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]);
|
||||
}
|
||||
|
||||
assertEquals(types.length,pythonVariables.getVariables().length);
|
||||
|
||||
}
|
||||
|
||||
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) {
|
||||
pythonVariables.add(key, type);
|
||||
assertNull(pythonVariables.getValue(key));
|
||||
pythonVariables.setValue(key,value);
|
||||
assertNotNull(pythonVariables.getValue(key));
|
||||
Object actualValue = pythonVariables.getValue(key);
|
||||
if (expectedValue instanceof Object[]){
|
||||
assertTrue(actualValue instanceof Object[]);
|
||||
Object[] actualArr = (Object[])actualValue;
|
||||
Object[] expectedArr = (Object[])expectedValue;
|
||||
assertArrayEquals(expectedArr, actualArr);
|
||||
}
|
||||
else{
|
||||
assertEquals(expectedValue,pythonVariables.getValue(key));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -29,7 +29,7 @@ public class TestSerde {
|
|||
public static JsonSerializer j = new JsonSerializer();
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testBasicSerde() throws Exception{
|
||||
public void testBasicSerde(){
|
||||
Schema schema = new Schema.Builder()
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
|
@ -37,10 +37,9 @@ public class TestSerde {
|
|||
.addColumnDouble("col4")
|
||||
.build();
|
||||
|
||||
Transform t = new PythonTransform(
|
||||
"col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0",
|
||||
schema
|
||||
);
|
||||
Transform t = PythonTransform.builder().code(
|
||||
"col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0"
|
||||
).inputSchema(schema).outputSchema(schema).build();
|
||||
|
||||
String yaml = y.serialize(t);
|
||||
String json = j.serialize(t);
|
||||
|
|
|
@ -247,10 +247,9 @@ public class ExecutionTest extends BaseSparkTest {
|
|||
.addColumnInteger("col1").addColumnDouble("col2").build();
|
||||
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).transform(
|
||||
new PythonTransform(
|
||||
pythonCode,
|
||||
finalSchema
|
||||
)
|
||||
PythonTransform.builder().code(
|
||||
"first = np.sin(first)\nsecond = np.cos(second)")
|
||||
.outputSchema(finalSchema).build()
|
||||
).build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
|
@ -288,10 +287,9 @@ public class ExecutionTest extends BaseSparkTest {
|
|||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).transform(
|
||||
new PythonTransform(
|
||||
pythonCode,
|
||||
finalSchema
|
||||
)
|
||||
PythonTransform.builder().code(
|
||||
"first = np.sin(first)\nsecond = np.cos(second)")
|
||||
.outputSchema(schema).build()
|
||||
).build();
|
||||
|
||||
INDArray zeros = Nd4j.zeros(shape);
|
||||
|
|
2
pom.xml
2
pom.xml
|
@ -294,6 +294,8 @@
|
|||
|
||||
<python.version>3.7.5</python.version>
|
||||
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
||||
<numpy.version>1.17.3</numpy.version>
|
||||
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
||||
|
||||
<openblas.version>0.3.7</openblas.version>
|
||||
<mkl.version>2019.5</mkl.version>
|
||||
|
|
Loading…
Reference in New Issue