Python updates (#86)

* python updates

* fix cyclic deps

* konduit updates

* konduit updates

* fix list

* fixes

* sync pyvars test

* setuprun comments

* Version fix, other module test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* bug fix using advanced hacking skillzz
master
Fariz Rahman 2019-12-02 13:50:23 +05:30 committed by Alex Black
parent 8123d9fa9b
commit 1adc25919c
24 changed files with 2724 additions and 691 deletions

View File

@ -256,11 +256,9 @@ public class ExecutionTest {
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(
new PythonTransform(
"first = np.sin(first)\nsecond = np.cos(second)",
schema
)
)
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(schema).build())
.build();
List<List<Writable>> functions = new ArrayList<>();

View File

@ -14,35 +14,40 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.writable.*;
import org.datavec.api.transform.schema.Schema;
import org.junit.Ignore;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@Test(timeout = 60000L)
@Test()
public void testStringConcat() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
@ -54,10 +59,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema)
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!"));
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
@ -68,7 +75,7 @@ public class TestPythonTransformProcess {
@Test(timeout = 60000L)
public void testMixedTypes() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
@ -83,11 +90,12 @@ public class TestPythonTransformProcess {
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema)
).build();
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList((Writable)
new IntWritable(10),
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
@ -105,7 +113,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2);
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@ -116,12 +124,14 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema)
).build();
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
@ -139,7 +149,7 @@ public class TestPythonTransformProcess {
INDArray expectedOutput = arr1.add(arr2);
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@ -150,11 +160,13 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema)
).build();
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1),
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
@ -172,7 +184,7 @@ public class TestPythonTransformProcess {
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
@ -183,11 +195,14 @@ public class TestPythonTransformProcess {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(pythonCode, finalSchema)
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList(
(Writable) new NDArrayWritable(arr1),
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
@ -199,8 +214,8 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
public void testPythonFilter(){
Schema schema = new Schema.Builder().addColumnInteger("column").build();
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
@ -210,17 +225,17 @@ public class TestPythonTransformProcess {
Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{
Schema.Builder schemaBuilder = new Schema.Builder();
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
@ -241,33 +256,85 @@ public class TestPythonTransformProcess {
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
new PythonTransform(
pythonCode,
finalSchema
)
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).filter(
filter
).build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(
Arrays.asList((Writable) new IntWritable(5),
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList((Writable) new IntWritable(-3),
Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList((Writable) new IntWritable(5),
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
LocalTransformExecutor.execute(inputs,tp);
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals("hello world",execute.get(0).get(0).toString());
}
}

View File

@ -28,15 +28,21 @@
<dependencies>
<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
<version>1.1</version>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20190722</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>numpy-platform</artifactId>
<version>${numpy.javacpp.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>

View File

@ -16,10 +16,13 @@
package org.datavec.python;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType;
* @author Fariz Rahman
*/
@Getter
@NoArgsConstructor
public class NumpyArray {
private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
private static NativeOps nativeOps;
private long address;
private long[] shape;
private long[] strides;
private DataType dtype = DataType.FLOAT;
private DataType dtype;
private INDArray nd4jArray;
static {
//initialize
Nd4j.scalar(1.0);
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
}
public NumpyArray(long address, long[] shape, long strides[], boolean copy){
@Builder
public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
this.address = address;
this.shape = shape;
this.strides = strides;
this.dtype = dtype;
setND4JArray();
if (copy){
nd4jArray = nd4jArray.dup();
@ -57,8 +68,9 @@ public class NumpyArray {
public NumpyArray copy(){
return new NumpyArray(nd4jArray.dup());
}
public NumpyArray(long address, long[] shape, long strides[]){
this(address, shape, strides, false);
this(address, shape, strides, false,DataType.FLOAT);
}
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
@ -77,9 +89,9 @@ public class NumpyArray {
}
}
private void setND4JArray(){
private void setND4JArray() {
long size = 1;
for(long d: shape){
for(long d: shape) {
size *= d;
}
Pointer ptr = nativeOps.pointerForAddress(address);
@ -88,10 +100,11 @@ public class NumpyArray {
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
for (int i=0; i<strides.length; i++){
for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize;
}
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, 'c', dtype);
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
}

View File

@ -23,6 +23,8 @@ import org.datavec.api.writable.*;
import java.util.List;
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
/**
* Lets a condition be defined as a python method f that takes no arguments
* and returns a boolean indicating whether or not to filter a row.
@ -38,81 +40,28 @@ public class PythonCondition implements Condition {
private String code;
public PythonCondition(String pythonCode){
public PythonCondition(String pythonCode) {
org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
code = pythonCode;
}
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i=0; i<numCols; i++){
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
private PythonVariables getPyInputsFromWritables(List<Writable> writables){
PythonVariables ret = new PythonVariables();
for (String name: pyInputs.getVariables()){
int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx);
PythonVariables.Type pyType = pyInputs.getType(name);
switch (pyType){
case INT:
if (w instanceof LongWritable){
ret.addInt(name, ((LongWritable)w).get());
}
else{
ret.addInt(name, ((IntWritable)w).get());
}
break;
case FLOAT:
ret.addFloat(name, ((DoubleWritable)w).get());
break;
case STR:
ret.addStr(name, ((Text)w).toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
}
}
return ret;
}
@Override
public void setInputSchema(Schema inputSchema){
public void setInputSchema(Schema inputSchema) {
this.inputSchema = inputSchema;
try{
pyInputs = schemaToPythonVariables(inputSchema);
PythonVariables pyOuts = new PythonVariables();
pyOuts.addInt("out");
pythonTransform = new PythonTransform(
code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered
pyInputs,
pyOuts
);
pythonTransform = PythonTransform.builder()
.code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
.inputs(pyInputs)
.outputs(pyOuts)
.build();
}
catch (Exception e){
throw new RuntimeException(e);
@ -127,41 +76,47 @@ public class PythonCondition implements Condition {
return inputSchema;
}
public String[] outputColumnNames(){
@Override
public String[] outputColumnNames() {
String[] columnNames = new String[inputSchema.numColumns()];
inputSchema.getColumnNames().toArray(columnNames);
return columnNames;
}
@Override
public String outputColumnName(){
return outputColumnNames()[0];
}
@Override
public String[] columnNames(){
return outputColumnNames();
}
@Override
public String columnName(){
return outputColumnName();
}
@Override
public Schema transform(Schema inputSchema){
return inputSchema;
}
public boolean condition(List<Writable> list){
@Override
public boolean condition(List<Writable> list) {
PythonVariables inputs = getPyInputsFromWritables(list);
try{
PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
return ret;
}
catch (Exception e){
catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public boolean condition(Object input){
return condition(input);
}
@ -177,5 +132,37 @@ public class PythonCondition implements Condition {
throw new UnsupportedOperationException("not supported");
}
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables();
for (int i = 0; i < inputSchema.numColumns(); i++){
String name = inputSchema.getName(i);
Writable w = writables.get(i);
PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
switch (pyType){
case INT:
if (w instanceof LongWritable) {
ret.addInt(name, ((LongWritable)w).get());
}
else {
ret.addInt(name, ((IntWritable)w).get());
}
break;
case FLOAT:
ret.addFloat(name, ((DoubleWritable)w).get());
break;
case STR:
ret.addStr(name, w.toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
}
}
return ret;
}
}

View File

@ -16,16 +16,29 @@
package org.datavec.python;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.commons.io.IOUtils;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.nd4j.base.Preconditions;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
/**
* Row-wise Transform that applies arbitrary python code on each row
*
@ -34,31 +47,87 @@ import java.util.UUID;
@NoArgsConstructor
@Data
public class PythonTransform implements Transform{
public class PythonTransform implements Transform {
private String code;
private PythonVariables pyInputs;
private PythonVariables pyOutputs;
private String name;
private PythonVariables inputs;
private PythonVariables outputs;
private String name = UUID.randomUUID().toString();
private Schema inputSchema;
private Schema outputSchema;
private String outputDict;
private boolean returnAllVariables;
private boolean setupAndRun = false;
public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{
@Builder
public PythonTransform(String code,
PythonVariables inputs,
PythonVariables outputs,
String name,
Schema inputSchema,
Schema outputSchema,
String outputDict,
boolean returnAllInputs,
boolean setupAndRun) {
Preconditions.checkNotNull(code,"No code found to run!");
this.code = code;
this.pyInputs = pyInputs;
this.pyOutputs = pyOutputs;
this.name = UUID.randomUUID().toString();
this.returnAllVariables = returnAllInputs;
this.setupAndRun = setupAndRun;
if(inputs != null)
this.inputs = inputs;
if(outputs != null)
this.outputs = outputs;
if(name != null)
this.name = name;
if (outputDict != null) {
this.outputDict = outputDict;
this.outputs = new PythonVariables();
this.outputs.addDict(outputDict);
String helpers;
try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
helpers = IOUtils.toString(is, Charset.defaultCharset());
}catch (IOException e){
throw new RuntimeException("Error reading python code");
}
this.code += "\n\n" + helpers;
this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
}
try {
if(inputSchema != null) {
this.inputSchema = inputSchema;
if(inputs == null || inputs.isEmpty()) {
this.inputs = schemaToPythonVariables(inputSchema);
}
}
if(outputSchema != null) {
this.outputSchema = outputSchema;
if(outputs == null || outputs.isEmpty()) {
this.outputs = schemaToPythonVariables(outputSchema);
}
}
}catch(Exception e) {
throw new IllegalStateException(e);
}
}
@Override
public void setInputSchema(Schema inputSchema){
public void setInputSchema(Schema inputSchema) {
Preconditions.checkNotNull(inputSchema,"No input schema found!");
this.inputSchema = inputSchema;
try{
pyInputs = schemaToPythonVariables(inputSchema);
inputs = schemaToPythonVariables(inputSchema);
}catch (Exception e){
throw new RuntimeException(e);
}
if (outputSchema == null){
if (outputSchema == null && outputDict == null){
outputSchema = inputSchema;
}
@ -88,12 +157,42 @@ public class PythonTransform implements Transform{
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public List<Writable> map(List<Writable> writables){
public List<Writable> map(List<Writable> writables) {
PythonVariables pyInputs = getPyInputsFromWritables(writables);
Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");
try{
PythonExecutioner.exec(code, pyInputs, pyOutputs);
return getWritablesFromPyOutputs(pyOutputs);
if (returnAllVariables) {
if (setupAndRun){
return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
}
return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
}
if (outputDict != null) {
if (setupAndRun) {
PythonExecutioner.execWithSetupAndRun(this, pyInputs);
}else{
PythonExecutioner.exec(this, pyInputs);
}
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
return getWritablesFromPyOutputs(out);
}
else {
if (setupAndRun) {
PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
}else{
PythonExecutioner.exec(code, pyInputs, outputs);
}
return getWritablesFromPyOutputs(outputs);
}
}
catch (Exception e){
throw new RuntimeException(e);
@ -102,7 +201,7 @@ public class PythonTransform implements Transform{
@Override
public String[] outputColumnNames(){
return pyOutputs.getVariables();
return outputs.getVariables();
}
@Override
@ -111,7 +210,7 @@ public class PythonTransform implements Transform{
}
@Override
public String[] columnNames(){
return pyOutputs.getVariables();
return outputs.getVariables();
}
@Override
@ -124,14 +223,13 @@ public class PythonTransform implements Transform{
}
private PythonVariables getPyInputsFromWritables(List<Writable> writables){
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables();
for (String name: pyInputs.getVariables()){
for (String name: inputs.getVariables()) {
int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx);
PythonVariables.Type pyType = pyInputs.getType(name);
PythonVariables.Type pyType = inputs.getType(name);
switch (pyType){
case INT:
if (w instanceof LongWritable){
@ -143,7 +241,7 @@ public class PythonTransform implements Transform{
break;
case FLOAT:
if (w instanceof DoubleWritable){
if (w instanceof DoubleWritable) {
ret.addFloat(name, ((DoubleWritable)w).get());
}
else{
@ -151,96 +249,99 @@ public class PythonTransform implements Transform{
}
break;
case STR:
ret.addStr(name, ((Text)w).toString());
ret.addStr(name, w.toString());
break;
case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get());
break;
default:
throw new RuntimeException("Unsupported input type:" + pyType);
}
}
return ret;
}
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts){
private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
List<Writable> out = new ArrayList<>();
for (int i=0; i<outputSchema.numColumns(); i++){
String name = outputSchema.getName(i);
PythonVariables.Type pyType = pyOutputs.getType(name);
String[] varNames;
varNames = pyOuts.getVariables();
Schema.Builder schemaBuilder = new Schema.Builder();
for (int i = 0; i < varNames.length; i++) {
String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name);
switch (pyType){
case INT:
out.add((Writable) new LongWritable(pyOuts.getIntValue(name)));
schemaBuilder.addColumnLong(name);
break;
case FLOAT:
out.add((Writable) new DoubleWritable(pyOuts.getFloatValue(name)));
schemaBuilder.addColumnDouble(name);
break;
case STR:
out.add((Writable) new Text(pyOuts.getStrValue(name)));
case DICT:
case LIST:
schemaBuilder.addColumnString(name);
break;
case NDARRAY:
out.add((Writable) new NDArrayWritable(pyOuts.getNDArrayValue(name).getNd4jArray()));
NumpyArray arr = pyOuts.getNDArrayValue(name);
schemaBuilder.addColumnNDArray(name, arr.getShape());
break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
}
}
this.outputSchema = schemaBuilder.build();
for (int i = 0; i < varNames.length; i++) {
String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name);
switch (pyType){
case INT:
out.add(new LongWritable(pyOuts.getIntValue(name)));
break;
case FLOAT:
out.add(new DoubleWritable(pyOuts.getFloatValue(name)));
break;
case STR:
out.add(new Text(pyOuts.getStrValue(name)));
break;
case NDARRAY:
NumpyArray arr = pyOuts.getNDArrayValue(name);
out.add(new NDArrayWritable(arr.getNd4jArray()));
break;
case DICT:
Map<?, ?> dictValue = pyOuts.getDictValue(name);
Map noNullValues = new java.util.HashMap<>();
for(Map.Entry entry : dictValue.entrySet()) {
if(entry.getValue() != org.json.JSONObject.NULL) {
noNullValues.put(entry.getKey(), entry.getValue());
}
}
try {
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
}
break;
case LIST:
Object[] listValue = pyOuts.getListValue(name);
try {
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
}
break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
}
}
return out;
}
public PythonTransform(String code) throws Exception{
this.code = code;
this.name = UUID.randomUUID().toString();
}
private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i=0; i<numCols; i++){
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
public PythonTransform(String code, Schema outputSchema) throws Exception{
this.code = code;
this.name = UUID.randomUUID().toString();
this.outputSchema = outputSchema;
this.pyOutputs = schemaToPythonVariables(outputSchema);
}
public String getName() {
return name;
}
public String getCode(){
return code;
}
public PythonVariables getInputs() {
return pyInputs;
}
public PythonVariables getOutputs() {
return pyOutputs;
}
}

View File

@ -0,0 +1,306 @@
package org.datavec.python;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.BooleanMetaData;
import org.datavec.api.transform.schema.Schema;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* List of utilities for executing python transforms.
*
* @author Adam Gibson
*/
public class PythonUtils {
/**
* Create a {@link Schema}
* from {@link PythonVariables}.
* Types are mapped to types of the same name.
* @param input the input {@link PythonVariables}
* @return the output {@link Schema}
*/
public static Schema fromPythonVariables(PythonVariables input) {
Schema.Builder schemaBuilder = new Schema.Builder();
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none.");
for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) {
switch(entry.getValue()) {
case INT:
schemaBuilder.addColumnInteger(entry.getKey());
break;
case STR:
schemaBuilder.addColumnString(entry.getKey());
break;
case FLOAT:
schemaBuilder.addColumnFloat(entry.getKey());
break;
case NDARRAY:
schemaBuilder.addColumnNDArray(entry.getKey(),null);
break;
case BOOL:
schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
}
}
return schemaBuilder.build();
}
/**
* Create a {@link Schema} from an input
* {@link PythonVariables}
* Types are mapped to types of the same name
* @param input the input schema
* @return the output python variables.
*/
public static PythonVariables fromSchema(Schema input) {
PythonVariables ret = new PythonVariables();
for(int i = 0; i < input.numColumns(); i++) {
String currColumnName = input.getName(i);
ColumnType columnType = input.getType(i);
switch(columnType) {
case NDArray:
ret.add(currColumnName, PythonVariables.Type.NDARRAY);
break;
case Boolean:
ret.add(currColumnName, PythonVariables.Type.BOOL);
break;
case Categorical:
case String:
ret.add(currColumnName, PythonVariables.Type.STR);
break;
case Double:
case Float:
ret.add(currColumnName, PythonVariables.Type.FLOAT);
break;
case Integer:
case Long:
ret.add(currColumnName, PythonVariables.Type.INT);
break;
case Bytes:
break;
case Time:
throw new UnsupportedOperationException("Unable to process dates with python yet.");
}
}
return ret;
}
/**
* Convert a {@link Schema}
* to {@link PythonVariables}
* @param schema the input schema
* @return the output {@link PythonVariables} where each
* name in the map is associated with a column name in the schema.
* A proper type is also chosen based on the schema
* @throws Exception
*/
public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception {
PythonVariables pyVars = new PythonVariables();
int numCols = schema.numColumns();
for (int i = 0; i < numCols; i++) {
String colName = schema.getName(i);
ColumnType colType = schema.getType(i);
switch (colType){
case Long:
case Integer:
pyVars.addInt(colName);
break;
case Double:
case Float:
pyVars.addFloat(colName);
break;
case String:
pyVars.addStr(colName);
break;
case NDArray:
pyVars.addNDArray(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
}
return pyVars;
}
public static NumpyArray mapToNumpyArray(Map map){
String dtypeName = (String)map.get("dtype");
DataType dtype;
if (dtypeName.equals("float64")){
dtype = DataType.DOUBLE;
}
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT;
}
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT;
}
else if (dtypeName.equals("int32")){
dtype = DataType.INT;
}
else if (dtypeName.equals("int64")){
dtype = DataType.LONG;
}
else{
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
List shapeList = (List)map.get("shape");
long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i);
}
List strideList = (List)map.get("shape");
long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i);
}
long address = (Long)map.get("address");
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
return numpyArray;
}
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){
Map dict = pyvars.getDictValue(key);
String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]);
PythonVariables pyvars2 = new PythonVariables();
for (String subkey: keys){
Object value = dict.get(subkey);
if (value instanceof Map){
Map map = (Map)value;
if (map.containsKey("_is_numpy_array")){
pyvars2.addNDArray(subkey, mapToNumpyArray(map));
}
else{
pyvars2.addDict(subkey, (Map)value);
}
}
else if (value instanceof List){
pyvars2.addList(subkey, ((List) value).toArray());
}
else if (value instanceof String){
System.out.println((String)value);
pyvars2.addStr(subkey, (String) value);
}
else if (value instanceof Integer || value instanceof Long) {
Number number = (Number) value;
pyvars2.addInt(subkey, number.intValue());
}
else if (value instanceof Float || value instanceof Double) {
Number number = (Number) value;
pyvars2.addFloat(subkey, number.doubleValue());
}
else if (value instanceof NumpyArray){
pyvars2.addNDArray(subkey, (NumpyArray)value);
}
else if (value == null){
pyvars2.addStr(subkey, "None"); // FixMe
}
else{
throw new RuntimeException("Unsupported type!" + value);
}
}
return pyvars2;
}
public static long[] jsonArrayToLongArray(JSONArray jsonArray){
long[] longs = new long[jsonArray.length()];
for (int i=0; i<longs.length; i++){
longs[i] = jsonArray.getLong(i);
}
return longs;
}
public static Map<String, Object> toMap(JSONObject jsonobj) {
Map<String, Object> map = new HashMap<>();
String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
for (String key: keys){
Object value = jsonobj.get(key);
if (value instanceof JSONArray) {
value = toList((JSONArray) value);
} else if (value instanceof JSONObject) {
JSONObject jsonobj2 = (JSONObject)value;
if (jsonobj2.has("_is_numpy_array")){
value = jsonToNumpyArray(jsonobj2);
}
else{
value = toMap(jsonobj2);
}
}
map.put(key, value);
} return map;
}
public static List<Object> toList(JSONArray array) {
List<Object> list = new ArrayList<>();
for (int i = 0; i < array.length(); i++) {
Object value = array.get(i);
if (value instanceof JSONArray) {
value = toList((JSONArray) value);
} else if (value instanceof JSONObject) {
JSONObject jsonobj2 = (JSONObject) value;
if (jsonobj2.has("_is_numpy_array")) {
value = jsonToNumpyArray(jsonobj2);
} else {
value = toMap(jsonobj2);
}
}
list.add(value);
}
return list;
}
private static NumpyArray jsonToNumpyArray(JSONObject map){
String dtypeName = (String)map.get("dtype");
DataType dtype;
if (dtypeName.equals("float64")){
dtype = DataType.DOUBLE;
}
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT;
}
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT;
}
else if (dtypeName.equals("int32")){
dtype = DataType.INT;
}
else if (dtypeName.equals("int64")){
dtype = DataType.LONG;
}
else{
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
List shapeList = (List)map.get("shape");
long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i);
}
List strideList = (List)map.get("shape");
long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i);
}
long address = (Long)map.get("address");
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
return numpyArray;
}
}

View File

@ -17,8 +17,8 @@
package org.datavec.python;
import lombok.Data;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.JSONObject;
import org.json.JSONArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;
@ -31,8 +31,8 @@ import java.util.*;
* @author Fariz Rahman
*/
@Data
public class PythonVariables implements Serializable{
@lombok.Data
public class PythonVariables implements java.io.Serializable {
public enum Type{
BOOL,
@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{
FLOAT,
NDARRAY,
LIST,
FILE
FILE,
DICT
}
private Map<String, String> strVars = new HashMap<String, String>();
private Map<String, Long> intVars = new HashMap<String, Long>();
private Map<String, Double> floatVars = new HashMap<String, Double>();
private Map<String, Boolean> boolVars = new HashMap<String, Boolean>();
private Map<String, NumpyArray> ndVars = new HashMap<String, NumpyArray>();
private Map<String, Object[]> listVars = new HashMap<String, Object[]>();
private Map<String, String> fileVars = new HashMap<String, String>();
private Map<String, Type> vars = new HashMap<String, Type>();
private Map<Type, Map> maps = new HashMap<Type, Map>();
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();
/**
* Returns a copy of the variable
* schema in this array without the values
* @return an empty variables clone
* with no values
*/
public PythonVariables copySchema(){
PythonVariables ret = new PythonVariables();
for (String varName: getVariables()){
@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{
}
return ret;
}
public PythonVariables(){
maps.put(Type.BOOL, boolVars);
maps.put(Type.STR, strVars);
maps.put(Type.INT, intVars);
maps.put(Type.FLOAT, floatVars);
maps.put(Type.NDARRAY, ndVars);
maps.put(Type.LIST, listVars);
maps.put(Type.FILE, fileVars);
/**
*
*/
public PythonVariables() {
maps.put(PythonVariables.Type.BOOL, boolVariables);
maps.put(PythonVariables.Type.STR, strVariables);
maps.put(PythonVariables.Type.INT, intVariables);
maps.put(PythonVariables.Type.FLOAT, floatVariables);
maps.put(PythonVariables.Type.NDARRAY, ndVars);
maps.put(PythonVariables.Type.LIST, listVariables);
maps.put(PythonVariables.Type.FILE, fileVariables);
maps.put(PythonVariables.Type.DICT, dictVariables);
}
/**
*
* @return true if there are no variables.
*/
public boolean isEmpty() {
return getVariables().length < 1;
}
@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{
break;
case FILE:
addFile(name);
break;
case DICT:
addDict(name);
}
}
@ -113,248 +137,459 @@ public class PythonVariables implements Serializable{
* @param name name of the variable
* @param type type of the variable
* @param value value of the variable (must be instance of expected type)
* @throws Exception
*/
public void add (String name, Type type, Object value) throws Exception{
public void add(String name, Type type, Object value) {
add(name, type);
setValue(name, value);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addDict(String name) {
vars.put(name, PythonVariables.Type.DICT);
dictVariables.put(name,null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addBool(String name){
vars.put(name, Type.BOOL);
boolVars.put(name, null);
vars.put(name, PythonVariables.Type.BOOL);
boolVariables.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addStr(String name){
vars.put(name, Type.STR);
strVars.put(name, null);
vars.put(name, PythonVariables.Type.STR);
strVariables.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addInt(String name){
vars.put(name, Type.INT);
intVars.put(name, null);
vars.put(name, PythonVariables.Type.INT);
intVariables.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addFloat(String name){
vars.put(name, Type.FLOAT);
floatVars.put(name, null);
vars.put(name, PythonVariables.Type.FLOAT);
floatVariables.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addNDArray(String name){
vars.put(name, Type.NDARRAY);
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addList(String name){
vars.put(name, Type.LIST);
listVars.put(name, null);
vars.put(name, PythonVariables.Type.LIST);
listVariables.put(name, null);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addFile(String name){
vars.put(name, Type.FILE);
fileVars.put(name, null);
}
public void addBool(String name, boolean value){
vars.put(name, Type.BOOL);
boolVars.put(name, value);
vars.put(name, PythonVariables.Type.FILE);
fileVariables.put(name, null);
}
public void addStr(String name, String value){
vars.put(name, Type.STR);
strVars.put(name, value);
/**
* Add a boolean variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addBool(String name, boolean value) {
vars.put(name, PythonVariables.Type.BOOL);
boolVariables.put(name, value);
}
public void addInt(String name, int value){
vars.put(name, Type.INT);
intVars.put(name, (long)value);
/**
* Add a string variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addStr(String name, String value) {
vars.put(name, PythonVariables.Type.STR);
strVariables.put(name, value);
}
public void addInt(String name, long value){
vars.put(name, Type.INT);
intVars.put(name, value);
/**
* Add an int variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addInt(String name, int value) {
vars.put(name, PythonVariables.Type.INT);
intVariables.put(name, (long)value);
}
public void addFloat(String name, double value){
vars.put(name, Type.FLOAT);
floatVars.put(name, value);
/**
* Add a long variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addInt(String name, long value) {
vars.put(name, PythonVariables.Type.INT);
intVariables.put(name, value);
}
public void addFloat(String name, float value){
vars.put(name, Type.FLOAT);
floatVars.put(name, (double)value);
/**
* Add a double variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addFloat(String name, double value) {
vars.put(name, PythonVariables.Type.FLOAT);
floatVariables.put(name, value);
}
public void addNDArray(String name, NumpyArray value){
vars.put(name, Type.NDARRAY);
/**
* Add a float variable to
* the set of variables
* @param name the field to add
* @param value the value to add
*/
public void addFloat(String name, float value) {
vars.put(name, PythonVariables.Type.FLOAT);
floatVariables.put(name, (double)value);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, NumpyArray value) {
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, value);
}
public void addNDArray(String name, INDArray value){
vars.put(name, Type.NDARRAY);
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, new NumpyArray(value));
}
public void addList(String name, Object[] value){
vars.put(name, Type.LIST);
listVars.put(name, value);
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addList(String name, Object[] value) {
vars.put(name, PythonVariables.Type.LIST);
listVariables.put(name, value);
}
public void addFile(String name, String value){
vars.put(name, Type.FILE);
fileVars.put(name, value);
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addFile(String name, String value) {
vars.put(name, PythonVariables.Type.FILE);
fileVariables.put(name, value);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add
*/
public void addDict(String name, java.util.Map value) {
vars.put(name, PythonVariables.Type.DICT);
dictVariables.put(name, value);
}
/**
*
* @param name name of the variable
* @param value new value for the variable
* @throws Exception
*/
public void setValue(String name, Object value) {
Type type = vars.get(name);
if (type == Type.BOOL){
boolVars.put(name, (Boolean)value);
if (type == PythonVariables.Type.BOOL){
boolVariables.put(name, (Boolean)value);
}
else if (type == Type.INT){
if (value instanceof Long){
intVars.put(name, ((Long)value));
}
else if (value instanceof Integer){
intVars.put(name, ((Integer)value).longValue());
}
else if (type == PythonVariables.Type.INT){
Number number = (Number) value;
intVariables.put(name, number.longValue());
}
else if (type == Type.FLOAT){
floatVars.put(name, (Double)value);
else if (type == PythonVariables.Type.FLOAT){
Number number = (Number) value;
floatVariables.put(name, number.doubleValue());
}
else if (type == Type.NDARRAY){
else if (type == PythonVariables.Type.NDARRAY){
if (value instanceof NumpyArray){
ndVars.put(name, (NumpyArray)value);
}
else if (value instanceof INDArray){
ndVars.put(name, new NumpyArray((INDArray) value));
else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
}
else{
throw new RuntimeException("Unsupported type: " + value.getClass().toString());
}
}
else if (type == Type.LIST){
listVars.put(name, (Object[]) value);
else if (type == PythonVariables.Type.LIST) {
if (value instanceof java.util.List) {
value = ((java.util.List) value).toArray();
listVariables.put(name, (Object[]) value);
}
else if(value instanceof org.json.JSONArray) {
org.json.JSONArray jsonArray = (org.json.JSONArray) value;
Object[] copyArr = new Object[jsonArray.length()];
for(int i = 0; i < copyArr.length; i++) {
copyArr[i] = jsonArray.get(i);
}
listVariables.put(name, copyArr);
}
else {
listVariables.put(name, (Object[]) value);
}
}
else if (type == Type.FILE){
fileVars.put(name, (String)value);
else if(type == PythonVariables.Type.DICT) {
dictVariables.put(name,(java.util.Map<?,?>) value);
}
else if (type == PythonVariables.Type.FILE){
fileVariables.put(name, (String)value);
}
else{
strVars.put(name, (String)value);
strVariables.put(name, (String)value);
}
}
public Object getValue(String name){
/**
* Do a general object lookup.
* The look up will happen relative to the {@link Type}
* of variable is described in the
* @param name the name of the variable to get
* @return teh value for the variable with the given name
*/
public Object getValue(String name) {
Type type = vars.get(name);
Map map = maps.get(type);
java.util.Map map = maps.get(type);
return map.get(name);
}
/**
* Returns a boolean variable with the given name.
* @param name the variable name to get the value for
* @return the retrieved boolean value
*/
public boolean getBooleanValue(String name) {
return boolVariables.get(name);
}
/**
*
* @param name the variable name
* @return the dictionary value
*/
public java.util.Map<?,?> getDictValue(String name) {
return dictVariables.get(name);
}
/**
/**
*
* @param name the variable name
* @return the string value
*/
public String getStrValue(String name){
return strVars.get(name);
return strVariables.get(name);
}
public long getIntValue(String name){
return intVars.get(name);
/**
*
* @param name the variable name
* @return the long value
*/
public Long getIntValue(String name){
return intVariables.get(name);
}
public double getFloatValue(String name){
return floatVars.get(name);
/**
*
* @param name the variable name
* @return the float value
*/
public Double getFloatValue(String name){
return floatVariables.get(name);
}
/**
*
* @param name the variable name
* @return the numpy array value
*/
public NumpyArray getNDArrayValue(String name){
return ndVars.get(name);
}
/**
*
* @param name the variable name
* @return the list value as an object array
*/
public Object[] getListValue(String name){
return listVars.get(name);
return listVariables.get(name);
}
/**
*
* @param name the variable name
* @return the value of the given file name
*/
public String getFileValue(String name){
return fileVars.get(name);
return fileVariables.get(name);
}
/**
* Returns the type for the given variable name
* @param name the name of the variable to get the type for
* @return the type for the given variable
*/
public Type getType(String name){
return vars.get(name);
}
/**
* Get all the variables present as a string array
* @return the variable names for this variable sset
*/
public String[] getVariables() {
String[] strArr = new String[vars.size()];
return vars.keySet().toArray(strArr);
}
public Map<String, Boolean> getBoolVariables(){
return boolVars;
}
public Map<String, String> getStrVariables(){
return strVars;
}
public Map<String, Long> getIntVariables(){
return intVars;
}
public Map<String, Double> getFloatVariables(){
return floatVars;
}
public Map<String, NumpyArray> getNDArrayVariables(){
return ndVars;
}
public Map<String, Object[]> getListVariables(){
return listVars;
}
public Map<String, String> getFileVariables(){
return fileVars;
}
public JSONArray toJSON(){
JSONArray arr = new JSONArray();
/**
* This variables set as its json representation (an array of json objects)
* @return the json array output
*/
public org.json.JSONArray toJSON(){
org.json.JSONArray arr = new org.json.JSONArray();
for (String varName: getVariables()){
JSONObject var = new JSONObject();
org.json.JSONObject var = new org.json.JSONObject();
var.put("name", varName);
String varType = getType(varName).toString();
var.put("type", varType);
arr.add(var);
arr.put(var);
}
return arr;
}
public static PythonVariables fromJSON(JSONArray jsonArray){
/**
* Create a schema from a map.
* This is an empty PythonVariables
* that just contains names and types with no values
* @param inputTypes the input types to convert
* @return the schema from the given map
*/
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
PythonVariables ret = new PythonVariables();
for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) {
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
}
return ret;
}
/**
* Get the python variable state relative to the
* input json array
* @param jsonArray the input json array
* @return the python variables based on the input json array
*/
public static PythonVariables fromJSON(org.json.JSONArray jsonArray){
PythonVariables pyvars = new PythonVariables();
for (int i=0; i<jsonArray.size(); i++){
JSONObject input = (JSONObject) jsonArray.get(i);
for (int i = 0; i < jsonArray.length(); i++) {
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
String varName = (String)input.get("name");
String varType = (String)input.get("type");
if (varType.equals("BOOL")){
if (varType.equals("BOOL")) {
pyvars.addBool(varName);
}
else if (varType.equals("INT")){
else if (varType.equals("INT")) {
pyvars.addInt(varName);
}
else if (varType.equals("FlOAT")){
pyvars.addFloat(varName);
}
else if (varType.equals("STR")){
else if (varType.equals("STR")) {
pyvars.addStr(varName);
}
else if (varType.equals("LIST")){
else if (varType.equals("LIST")) {
pyvars.addList(varName);
}
else if (varType.equals("FILE")){
pyvars.addFile(varName);
}
else if (varType.equals("NDARRAY")){
else if (varType.equals("NDARRAY")) {
pyvars.addNDArray(varName);
}
else if(varType.equals("DICT")) {
pyvars.addDict(varName);
}
}
return pyvars;

View File

@ -0,0 +1,5 @@
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
import sys
this = sys.modules[__name__]
for n in dir():
if n[0]!='_': delattr(this, n)

View File

@ -0,0 +1 @@
loc = {}

View File

@ -0,0 +1,20 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def maybe_serialize_ndarray_metadata(x):
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def is_json_ready(key, value):
return key is not 'f2' and not inspect.ismodule(value) \
and not hasattr(value, '__call__')

View File

@ -0,0 +1,202 @@
#patch
"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
from numpy.core._multiarray_umath import (
add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
_add_docstring = add_docstring
def add_docstring(*args):
try:
_add_docstring(*args)
except:
pass
add_docstring(
implement_array_function,
"""
Implement a function with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Arguments
---------
implementation : function
Function that implements the operation on NumPy array without
overrides when called like ``implementation(*args, **kwargs)``.
public_api : function
Function exposed by NumPy's public API originally called like
``public_api(*args, **kwargs)`` on which arguments are now being
checked.
relevant_args : iterable
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
""")
# exposed for testing purposes; used internally by implement_array_function
add_docstring(
_get_implementing_args,
"""
Collect arguments on which to call __array_function__.
Parameters
----------
relevant_args : iterable of array-like
Iterable of possibly array-like arguments to check for
__array_function__ methods.
Returns
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
""")
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
def verify_matching_signatures(implementation, dispatcher):
"""Verify that a dispatcher function has the right signature."""
implementation_spec = ArgSpec(*getargspec(implementation))
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
if (implementation_spec.args != dispatcher_spec.args or
implementation_spec.varargs != dispatcher_spec.varargs or
implementation_spec.keywords != dispatcher_spec.keywords or
(bool(implementation_spec.defaults) !=
bool(dispatcher_spec.defaults)) or
(implementation_spec.defaults is not None and
len(implementation_spec.defaults) !=
len(dispatcher_spec.defaults))):
raise RuntimeError('implementation and dispatcher for %s have '
'different function signatures' % implementation)
if implementation_spec.defaults is not None:
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
raise RuntimeError('dispatcher functions can only use None for '
'default argument values')
def set_module(module):
"""Decorator for overriding __module__ on a function or class.
Example usage::
@set_module('numpy')
def example():
pass
assert example.__module__ == 'numpy'
"""
def decorator(func):
if module is not None:
func.__module__ = module
return func
return decorator
def array_function_dispatch(dispatcher, module=None, verify=True,
docs_from_dispatcher=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
See NEP-18 for example usage.
Parameters
----------
dispatcher : callable
Function that when called like ``dispatcher(*args, **kwargs)`` with
arguments from the NumPy function call returns an iterable of
array-like arguments to check for ``__array_function__``.
module : str, optional
__module__ attribute to set on new function, e.g., ``module='numpy'``.
By default, module is copied from the decorated function.
verify : bool, optional
If True, verify the that the signature of the dispatcher and decorated
function signatures match exactly: all required and optional arguments
should appear in order with the same names, but the default values for
all optional arguments should be ``None``. Only disable verification
if the dispatcher's signature needs to deviate for some particular
reason, e.g., because the function has a signature like
``func(*args, **kwargs)``.
docs_from_dispatcher : bool, optional
If True, copy docs from the dispatcher function onto the dispatched
function, rather than from the implementation. This is useful for
functions defined in C, which otherwise don't have docstrings.
Returns
-------
Function suitable for decorating the implementation of a NumPy function.
"""
if not ENABLE_ARRAY_FUNCTION:
# __array_function__ requires an explicit opt-in for now
def decorator(implementation):
if module is not None:
implementation.__module__ = module
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
return implementation
return decorator
def decorator(implementation):
if verify:
verify_matching_signatures(implementation, dispatcher)
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None:
public_api.__module__ = module
# TODO: remove this when we drop Python 2 support (functools.wraps)
# adds __wrapped__ automatically in later versions)
public_api.__wrapped__ = implementation
return public_api
return decorator
def array_function_from_dispatcher(
implementation, module=None, verify=True, docs_from_dispatcher=True):
"""Like array_function_dispatcher, but with function arguments flipped."""
def decorator(dispatcher):
return array_function_dispatch(
dispatcher, module, verify=verify,
docs_from_dispatcher=docs_from_dispatcher)(implementation)
return decorator

View File

@ -0,0 +1,172 @@
#patch 1
"""
========================
Random Number Generation
========================
==================== =========================================================
Utility functions
==============================================================================
random_sample Uniformly distributed floats over ``[0, 1)``.
random Alias for `random_sample`.
bytes Uniformly distributed random bytes.
random_integers Uniformly distributed integers in a given range.
permutation Randomly permute a sequence / generate a random sequence.
shuffle Randomly permute a sequence in place.
seed Seed the random number generator.
choice Random sample from 1-D array.
==================== =========================================================
==================== =========================================================
Compatibility functions
==============================================================================
rand Uniformly distributed values.
randn Normally distributed values.
ranf Uniformly distributed floating point numbers.
randint Uniformly distributed integers in a given range.
==================== =========================================================
==================== =========================================================
Univariate distributions
==============================================================================
beta Beta distribution over ``[0, 1]``.
binomial Binomial distribution.
chisquare :math:`\\chi^2` distribution.
exponential Exponential distribution.
f F (Fisher-Snedecor) distribution.
gamma Gamma distribution.
geometric Geometric distribution.
gumbel Gumbel distribution.
hypergeometric Hypergeometric distribution.
laplace Laplace distribution.
logistic Logistic distribution.
lognormal Log-normal distribution.
logseries Logarithmic series distribution.
negative_binomial Negative binomial distribution.
noncentral_chisquare Non-central chi-square distribution.
noncentral_f Non-central F distribution.
normal Normal / Gaussian distribution.
pareto Pareto distribution.
poisson Poisson distribution.
power Power distribution.
rayleigh Rayleigh distribution.
triangular Triangular distribution.
uniform Uniform distribution.
vonmises Von Mises circular distribution.
wald Wald (inverse Gaussian) distribution.
weibull Weibull distribution.
zipf Zipf's distribution over ranked data.
==================== =========================================================
==================== =========================================================
Multivariate distributions
==============================================================================
dirichlet Multivariate generalization of Beta distribution.
multinomial Multivariate generalization of the binomial distribution.
multivariate_normal Multivariate generalization of the normal distribution.
==================== =========================================================
==================== =========================================================
Standard distributions
==============================================================================
standard_cauchy Standard Cauchy-Lorentz distribution.
standard_exponential Standard exponential distribution.
standard_gamma Standard Gamma distribution.
standard_normal Standard normal distribution.
standard_t Standard Student's t-distribution.
==================== =========================================================
==================== =========================================================
Internal functions
==============================================================================
get_state Get tuple representing internal state of generator.
set_state Set state of generator.
==================== =========================================================
"""
from __future__ import division, absolute_import, print_function
import warnings
__all__ = [
'beta',
'binomial',
'bytes',
'chisquare',
'choice',
'dirichlet',
'exponential',
'f',
'gamma',
'geometric',
'get_state',
'gumbel',
'hypergeometric',
'laplace',
'logistic',
'lognormal',
'logseries',
'multinomial',
'multivariate_normal',
'negative_binomial',
'noncentral_chisquare',
'noncentral_f',
'normal',
'pareto',
'permutation',
'poisson',
'power',
'rand',
'randint',
'randn',
'random_integers',
'random_sample',
'rayleigh',
'seed',
'set_state',
'shuffle',
'standard_cauchy',
'standard_exponential',
'standard_gamma',
'standard_normal',
'standard_t',
'triangular',
'uniform',
'vonmises',
'wald',
'weibull',
'zipf'
]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
try:
from .mtrand import *
# Some aliases:
ranf = random = sample = random_sample
__all__.extend(['ranf', 'random', 'sample'])
except:
warnings.warn("numpy.random is not available when using multiple interpreters!")
def __RandomState_ctor():
"""Return a RandomState instance.
This function exists solely to assist (un)pickling.
Note that the state of the RandomState returned here is irrelevant, as this function's
entire purpose is to return a newly allocated RandomState whose state pickle can set.
Consequently the RandomState returned by this function is a freshly allocated copy
with a seed=0.
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
"""
return RandomState(seed=0)
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester

View File

@ -0,0 +1,20 @@
import sys
import traceback
import json
import inspect
try:
pass
sys.stdout.flush()
sys.stderr.flush()
except Exception as ex:
try:
exc_info = sys.exc_info()
finally:
print(ex)
traceback.print_exception(*exc_info)
sys.stdout.flush()
sys.stderr.flush()

View File

@ -0,0 +1,50 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def __maybe_serialize_ndarray_metadata(x):
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def __serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def __serialize_list(x):
import json
return json.dumps(__recursive_serialize_list(x))
def __serialize_dict(x):
import json
return json.dumps(__recursive_serialize_dict(x))
def __recursive_serialize_list(x):
out = []
for i in x:
if __is_numpy_array(i):
out.append(__serialize_ndarray_metadata(i))
elif isinstance(i, (list, tuple)):
out.append(__recursive_serialize_list(i))
elif isinstance(i, dict):
out.append(__recursive_serialize_dict(i))
else:
out.append(i)
return out
def __recursive_serialize_dict(x):
out = {}
for k in x:
v = x[k]
if __is_numpy_array(v):
out[k] = __serialize_ndarray_metadata(v)
elif isinstance(v, (list, tuple)):
out[k] = __recursive_serialize_list(v)
elif isinstance(v, dict):
out[k] = __recursive_serialize_dict(v)
else:
out[k] = v
return out

View File

@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Assert;
import org.junit.Test;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutionSandbox {
@Test
public void testInt(){
PythonExecutioner.setInterpreter("interp1");
PythonExecutioner.exec("a = 1");
PythonExecutioner.setInterpreter("interp2");
PythonExecutioner.exec("a = 2");
PythonExecutioner.setInterpreter("interp3");
PythonExecutioner.exec("a = 3");
PythonExecutioner.setInterpreter("interp1");
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp2");
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp3");
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
}
@Test
public void testNDArray(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 2");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 3");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
// PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
}
@Test
public void testNumpyRandom(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
}
}

View File

@ -15,17 +15,25 @@
******************************************************************************/
package org.datavec.python;
import org.junit.Ignore;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutioner {
@Test(timeout = 60000L)
@org.junit.Test
public void testPythonSysVersion() {
PythonExecutioner.exec("import sys; print(sys.version)");
}
@Test
public void testStr() throws Exception{
PythonVariables pyInputs = new PythonVariables();
@ -47,7 +55,7 @@ public class TestPythonExecutioner {
assertEquals("Hello World", z);
}
@Test(timeout = 60000L)
@Test
public void testInt()throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -55,7 +63,7 @@ public class TestPythonExecutioner {
pyInputs.addInt("x", 10);
pyInputs.addInt("y", 20);
String code = "z = x + y";
String code = "z = x + y";
pyOutputs.addInt("z");
@ -64,11 +72,11 @@ public class TestPythonExecutioner {
long z = pyOutputs.getIntValue("z");
assertEquals(30, z);
Assert.assertEquals(30, z);
}
@Test(timeout = 60000L)
@Test
public void testList() throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -88,18 +96,35 @@ public class TestPythonExecutioner {
Object[] z = pyOutputs.getListValue("z");
assertEquals(z.length, x.length + y.length);
Assert.assertEquals(z.length, x.length + y.length);
for (int i = 0; i < x.length; i++) {
if(x[i] instanceof Number) {
Number xNum = (Number) x[i];
Number zNum = (Number) z[i];
Assert.assertEquals(xNum.intValue(), zNum.intValue());
}
else {
Assert.assertEquals(x[i], z[i]);
}
for (int i=0; i < x.length; i++){
assertEquals(x[i], z[i]);
}
for (int i=0; i<y.length; i++){
assertEquals(y[i], z[x.length + i]);
for (int i = 0; i < y.length; i++){
if(y[i] instanceof Number) {
Number yNum = (Number) y[i];
Number zNum = (Number) z[x.length + i];
Assert.assertEquals(yNum.intValue(), zNum.intValue());
}
else {
Assert.assertEquals(y[i], z[x.length + i]);
}
}
}
@Test(timeout = 60000L)
@Test
public void testNDArrayFloat()throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -113,12 +138,17 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}
@Test(timeout = 60000L)
@Test
public void testTensorflowCustomAnaconda() {
PythonExecutioner.exec("import tensorflow as tf");
}
@Test
public void testNDArrayDouble()throws Exception {
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -132,10 +162,10 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}
@Test(timeout = 60000L)
@Test
public void testNDArrayShort()throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -149,11 +179,11 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}
@Test(timeout = 60000L)
@Test
public void testNDArrayInt()throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -167,11 +197,11 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}
@Test(timeout = 60000L)
@Test
public void testNDArrayLong()throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
@ -185,7 +215,7 @@ public class TestPythonExecutioner {
PythonExecutioner.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
assertEquals(6.0, z.sum().getDouble(0), 1e-5);
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}

View File

@ -0,0 +1,27 @@
package org.datavec.python;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonSetupAndRun {
@Test
public void testPythonWithSetupAndRun() throws Exception{
String code = "def setup():" +
"global counter;counter=0\n" +
"def run(step):" +
"global counter;" +
"counter+=step;" +
"return {\"counter\":counter}";
PythonVariables pyInputs = new PythonVariables();
pyInputs.addInt("step", 2);
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addInt("counter");
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
pyInputs.addInt("step", 3);
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
}
}

View File

@ -0,0 +1,102 @@
/*
*
* * ******************************************************************************
* * * Copyright (c) 2015-2019 Skymind Inc.
* * * Copyright (c) 2019 Konduit AI.
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.datavec.python;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.Collections;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestPythonVariables {
@Test
public void testImportNumpy(){
Nd4j.scalar(1.0);
System.out.println(System.getProperty("org.bytedeco.openblas.load"));
PythonExecutioner.exec("import numpy as np");
}
@Test
public void testDataAssociations() {
PythonVariables pythonVariables = new PythonVariables();
PythonVariables.Type[] types = {
PythonVariables.Type.INT,
PythonVariables.Type.FLOAT,
PythonVariables.Type.STR,
PythonVariables.Type.BOOL,
PythonVariables.Type.DICT,
PythonVariables.Type.LIST,
PythonVariables.Type.LIST,
PythonVariables.Type.FILE,
PythonVariables.Type.NDARRAY
};
NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
Object[] values = {
1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, Arrays.asList(1),"type", npArr
};
Object[] expectedValues = {
1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, new Object[]{1},"type", npArr
};
for(int i = 0; i < types.length; i++) {
testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]);
}
assertEquals(types.length,pythonVariables.getVariables().length);
}
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) {
pythonVariables.add(key, type);
assertNull(pythonVariables.getValue(key));
pythonVariables.setValue(key,value);
assertNotNull(pythonVariables.getValue(key));
Object actualValue = pythonVariables.getValue(key);
if (expectedValue instanceof Object[]){
assertTrue(actualValue instanceof Object[]);
Object[] actualArr = (Object[])actualValue;
Object[] expectedArr = (Object[])expectedValue;
assertArrayEquals(expectedArr, actualArr);
}
else{
assertEquals(expectedValue,pythonVariables.getValue(key));
}
}
}

View File

@ -29,7 +29,7 @@ public class TestSerde {
public static JsonSerializer j = new JsonSerializer();
@Test(timeout = 60000L)
public void testBasicSerde() throws Exception{
public void testBasicSerde(){
Schema schema = new Schema.Builder()
.addColumnInteger("col1")
.addColumnFloat("col2")
@ -37,10 +37,9 @@ public class TestSerde {
.addColumnDouble("col4")
.build();
Transform t = new PythonTransform(
"col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0",
schema
);
Transform t = PythonTransform.builder().code(
"col1+=3\ncol2+=2\ncol3+='a'\ncol4+=2.0"
).inputSchema(schema).outputSchema(schema).build();
String yaml = y.serialize(t);
String json = j.serialize(t);

View File

@ -247,10 +247,9 @@ public class ExecutionTest extends BaseSparkTest {
.addColumnInteger("col1").addColumnDouble("col2").build();
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform(
new PythonTransform(
pythonCode,
finalSchema
)
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(finalSchema).build()
).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -288,10 +287,9 @@ public class ExecutionTest extends BaseSparkTest {
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(schema).transform(
new PythonTransform(
pythonCode,
finalSchema
)
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(schema).build()
).build();
INDArray zeros = Nd4j.zeros(shape);

View File

@ -294,6 +294,8 @@
<python.version>3.7.5</python.version>
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
<numpy.version>1.17.3</numpy.version>
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
<openblas.version>0.3.7</openblas.version>
<mkl.version>2019.5</mkl.version>