diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java index 2f508f09e..19733f297 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java @@ -256,11 +256,9 @@ public class ExecutionTest { TransformProcess transformProcess = new TransformProcess.Builder(schema) .transform( - new PythonTransform( - "first = np.sin(first)\nsecond = np.cos(second)", - schema - ) - ) + PythonTransform.builder().code( + "first = np.sin(first)\nsecond = np.cos(second)") + .outputSchema(schema).build()) .build(); List> functions = new ArrayList<>(); diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java similarity index 64% rename from datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java rename to datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 77ba53e26..37df8ae52 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -14,35 +14,40 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.datavec.python; +package org.datavec.local.transforms.transform; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.filter.ConditionFilter; import org.datavec.api.transform.filter.Filter; -import org.datavec.api.writable.*; import org.datavec.api.transform.schema.Schema; -import org.junit.Ignore; +import org.datavec.local.transforms.LocalTransformExecutor; + +import org.datavec.api.writable.*; +import org.datavec.python.PythonCondition; +import org.datavec.python.PythonTransform; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - +import javax.annotation.concurrent.NotThreadSafe; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") +import static junit.framework.TestCase.assertTrue; +import static org.datavec.api.transform.schema.Schema.Builder; +import static org.junit.Assert.*; + +@NotThreadSafe public class TestPythonTransformProcess { - @Test(timeout = 60000L) + + @Test() public void testStringConcat() throws Exception{ - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnString("col1") .addColumnString("col2"); @@ -54,10 +59,12 @@ public class TestPythonTransformProcess { String pythonCode = "col3 = col1 + col2"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform(pythonCode, finalSchema) + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .build() ).build(); - 
List inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!")); + List inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!")); List outputs = tp.execute(inputs); assertEquals((outputs.get(0)).toString(), "Hello "); @@ -68,7 +75,7 @@ public class TestPythonTransformProcess { @Test(timeout = 60000L) public void testMixedTypes() throws Exception{ - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnInteger("col1") .addColumnFloat("col2") @@ -83,11 +90,12 @@ public class TestPythonTransformProcess { String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform(pythonCode, finalSchema) - ).build(); + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .inputSchema(initialSchema) + .build() ).build(); - List inputs = Arrays.asList((Writable) - new IntWritable(10), + List inputs = Arrays.asList((Writable)new IntWritable(10), new FloatWritable(3.5f), new Text("5"), new DoubleWritable(2.0) @@ -105,7 +113,7 @@ public class TestPythonTransformProcess { INDArray expectedOutput = arr1.add(arr2); - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnNDArray("col1", shape) .addColumnNDArray("col2", shape); @@ -116,12 +124,14 @@ public class TestPythonTransformProcess { String pythonCode = "col3 = col1 + col2"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform(pythonCode, finalSchema) - ).build(); + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .build() ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), - new NDArrayWritable(arr2) + (Writable) + new NDArrayWritable(arr1), + new NDArrayWritable(arr2) ); List outputs = tp.execute(inputs); @@ -139,7 +149,7 @@ public class TestPythonTransformProcess { INDArray expectedOutput = arr1.add(arr2); - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnNDArray("col1", shape) .addColumnNDArray("col2", shape); @@ -150,11 +160,13 @@ public class TestPythonTransformProcess { String pythonCode = "col3 = col1 + col2"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform(pythonCode, finalSchema) - ).build(); + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .build() ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), + (Writable) + new NDArrayWritable(arr1), new NDArrayWritable(arr2) ); @@ -172,7 +184,7 @@ public class TestPythonTransformProcess { INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape); INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE)); - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnNDArray("col1", shape) .addColumnNDArray("col2", shape); @@ -183,11 +195,14 @@ public class TestPythonTransformProcess { String pythonCode = "col3 = col1 + col2"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform(pythonCode, finalSchema) + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .build() ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), + (Writable) + new NDArrayWritable(arr1), new NDArrayWritable(arr2) ); @@ -199,8 +214,8 @@ 
public class TestPythonTransformProcess { } @Test(timeout = 60000L) - public void testPythonFilter(){ - Schema schema = new Schema.Builder().addColumnInteger("column").build(); + public void testPythonFilter() { + Schema schema = new Builder().addColumnInteger("column").build(); Condition condition = new PythonCondition( "f = lambda: column < 0" @@ -210,17 +225,17 @@ public class TestPythonTransformProcess { Filter filter = new ConditionFilter(condition); - assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10)))); - assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1)))); - assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0)))); - assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1)))); - assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0)))); + assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1)))); + assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10)))); } @Test(timeout = 60000L) public void testPythonFilterAndTransform() throws Exception{ - Schema.Builder schemaBuilder = new Schema.Builder(); + Builder schemaBuilder = new Builder(); schemaBuilder .addColumnInteger("col1") .addColumnFloat("col2") @@ -241,33 +256,85 @@ public class TestPythonTransformProcess { String pythonCode = "col6 = str(col1 + col2)"; TransformProcess tp = new TransformProcess.Builder(initialSchema).transform( - new PythonTransform( - pythonCode, - finalSchema - ) + PythonTransform.builder().code(pythonCode) + .outputSchema(finalSchema) + .build() ).filter( filter ).build(); List> inputs = new ArrayList<>(); inputs.add( - Arrays.asList((Writable) new IntWritable(5), + Arrays.asList( + (Writable) + new IntWritable(5), new FloatWritable(3.0f), new Text("abcd"), new DoubleWritable(2.1)) ); inputs.add( - Arrays.asList((Writable) new IntWritable(-3), + Arrays.asList( + (Writable) + new IntWritable(-3), new FloatWritable(3.0f), new Text("abcd"), new DoubleWritable(2.1)) ); inputs.add( - Arrays.asList((Writable) new IntWritable(5), + Arrays.asList( + (Writable) + new IntWritable(5), new FloatWritable(11.2f), new Text("abcd"), new DoubleWritable(2.1)) ); + LocalTransformExecutor.execute(inputs,tp); } -} + + + @Test + public void testPythonTransformNoOutputSpecified() throws Exception { + PythonTransform pythonTransform = PythonTransform.builder() + .code("a += 2; b = 'hello world'") + .returnAllInputs(true) + .build(); + List> inputs = new ArrayList<>(); + inputs.add(Arrays.asList((Writable)new IntWritable(1))); + Schema inputSchema = new Builder() + .addColumnInteger("a") + .build(); + + TransformProcess tp = new TransformProcess.Builder(inputSchema) + .transform(pythonTransform) + .build(); + List> execute = LocalTransformExecutor.execute(inputs, tp); + assertEquals(3,execute.get(0).get(0).toInt()); + assertEquals("hello world",execute.get(0).get(1).toString()); + + } + + @Test + public void testNumpyTransform() throws Exception { + PythonTransform pythonTransform = PythonTransform.builder() + .code("a += 2; b = 'hello world'") + .returnAllInputs(true) + .build(); + + List> inputs = new ArrayList<>(); + inputs.add(Arrays.asList((Writable) new 
NDArrayWritable(Nd4j.scalar(1).reshape(1,1)))); + Schema inputSchema = new Builder() + .addColumnNDArray("a",new long[]{1,1}) + .build(); + + TransformProcess tp = new TransformProcess.Builder(inputSchema) + .transform(pythonTransform) + .build(); + List> execute = LocalTransformExecutor.execute(inputs, tp); + assertFalse(execute.isEmpty()); + assertNotNull(execute.get(0)); + assertNotNull(execute.get(0).get(0)); + assertEquals("hello world",execute.get(0).get(0).toString()); + } + +} \ No newline at end of file diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 449364207..55cf6c5da 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -28,15 +28,21 @@ - com.googlecode.json-simple - json-simple - 1.1 + org.json + json + 20190722 org.bytedeco cpython-platform ${cpython-platform.version} + + org.bytedeco + numpy-platform + ${numpy.javacpp.version} + + com.google.code.findbugs jsr305 diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index a6ccc3036..ab49cf5ea 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -16,10 +16,13 @@ package org.datavec.python; +import lombok.Builder; import lombok.Getter; +import lombok.NoArgsConstructor; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; @@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType; * @author Fariz Rahman */ @Getter +@NoArgsConstructor public class NumpyArray { - private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private static NativeOps nativeOps; private long address; private long[] shape; private long[] strides; - private DataType dtype = DataType.FLOAT; + private DataType dtype; private INDArray nd4jArray; + static { + //initialize + Nd4j.scalar(1.0); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + } - public NumpyArray(long address, long[] shape, long strides[], boolean copy){ + @Builder + public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) { this.address = address; this.shape = shape; this.strides = strides; + this.dtype = dtype; setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); @@ -57,8 +68,9 @@ public class NumpyArray { public NumpyArray copy(){ return new NumpyArray(nd4jArray.dup()); } + public NumpyArray(long address, long[] shape, long strides[]){ - this(address, shape, strides, false); + this(address, shape, strides, false,DataType.FLOAT); } public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ @@ -77,9 +89,9 @@ public class NumpyArray { } } - private void setND4JArray(){ + private void setND4JArray() { long size = 1; - for(long d: shape){ + for(long d: shape) { size *= d; } Pointer ptr = nativeOps.pointerForAddress(address); @@ -88,10 +100,11 @@ public class NumpyArray { DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype); int elemSize = buff.getElementSize(); long[] nd4jStrides = new long[strides.length]; - for (int i=0; i= 1,"Python code must not be empty!"); code = pythonCode; } - private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{ - 
PythonVariables pyVars = new PythonVariables(); - int numCols = schema.numColumns(); - for (int i=0; i writables){ - PythonVariables ret = new PythonVariables(); - for (String name: pyInputs.getVariables()){ - int colIdx = inputSchema.getIndexOfColumn(name); - Writable w = writables.get(colIdx); - PythonVariables.Type pyType = pyInputs.getType(name); - switch (pyType){ - case INT: - if (w instanceof LongWritable){ - ret.addInt(name, ((LongWritable)w).get()); - } - else{ - ret.addInt(name, ((IntWritable)w).get()); - } - break; - case FLOAT: - ret.addFloat(name, ((DoubleWritable)w).get()); - break; - case STR: - ret.addStr(name, ((Text)w).toString()); - break; - case NDARRAY: - ret.addNDArray(name,((NDArrayWritable)w).get()); - break; - } - - } - return ret; - } @Override - public void setInputSchema(Schema inputSchema){ + public void setInputSchema(Schema inputSchema) { this.inputSchema = inputSchema; try{ pyInputs = schemaToPythonVariables(inputSchema); PythonVariables pyOuts = new PythonVariables(); pyOuts.addInt("out"); - pythonTransform = new PythonTransform( - code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered - pyInputs, - pyOuts - ); + pythonTransform = PythonTransform.builder() + .code(code + "\n\nout=f()\nout=0 if out is None else int(out)") + .inputs(pyInputs) + .outputs(pyOuts) + .build(); + } catch (Exception e){ throw new RuntimeException(e); @@ -127,41 +76,47 @@ public class PythonCondition implements Condition { return inputSchema; } - public String[] outputColumnNames(){ + @Override + public String[] outputColumnNames() { String[] columnNames = new String[inputSchema.numColumns()]; inputSchema.getColumnNames().toArray(columnNames); return columnNames; } + @Override public String outputColumnName(){ return outputColumnNames()[0]; } + @Override public String[] columnNames(){ return outputColumnNames(); } + @Override public String columnName(){ return outputColumnName(); } + @Override public Schema transform(Schema inputSchema){ return inputSchema; } - public boolean condition(List list){ + @Override + public boolean condition(List list) { PythonVariables inputs = getPyInputsFromWritables(list); try{ PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs()); boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0; return ret; } - catch (Exception e){ + catch (Exception e) { throw new RuntimeException(e); } - } + @Override public boolean condition(Object input){ return condition(input); } @@ -177,5 +132,37 @@ public class PythonCondition implements Condition { throw new UnsupportedOperationException("not supported"); } + private PythonVariables getPyInputsFromWritables(List writables) { + PythonVariables ret = new PythonVariables(); -} + for (int i = 0; i < inputSchema.numColumns(); i++){ + String name = inputSchema.getName(i); + Writable w = writables.get(i); + PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i)); + switch (pyType){ + case INT: + if (w instanceof LongWritable) { + ret.addInt(name, ((LongWritable)w).get()); + } + else { + ret.addInt(name, ((IntWritable)w).get()); + } + + break; + case FLOAT: + ret.addFloat(name, ((DoubleWritable)w).get()); + break; + case STR: + ret.addStr(name, w.toString()); + break; + case NDARRAY: + ret.addNDArray(name,((NDArrayWritable)w).get()); + break; + } + } + + return ret; + } + + +} \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java 
b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index c46d0d710..c6272e7ad 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -17,132 +17,504 @@ package org.datavec.python; -import java.io.File; -import java.io.FileInputStream; -import java.util.HashMap; +import java.io.*; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; import java.util.Map; -import java.util.regex.Pattern; + import lombok.extern.slf4j.Slf4j; -import org.json.simple.JSONArray; -import org.json.simple.JSONObject; -import org.json.simple.parser.JSONParser; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.IOUtils; +import org.json.JSONObject; +import org.json.JSONArray; import org.bytedeco.javacpp.*; import org.bytedeco.cpython.*; import static org.bytedeco.cpython.global.python.*; +import org.bytedeco.numpy.global.numpy; + +import static org.datavec.python.PythonUtils.*; + +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.io.ClassPathResource; + /** - * Python executioner + * Allows execution of python scripts managed by + * an internal interpreter. + * An end user may specify a python script to run + * via any of the execution methods available in this class. + * + * At static initialization time (when the class is first initialized) + * a number of components are setup: + * 1. The python path. A user may over ride this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY} + * + * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers, + * a user may also over ride the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType} + * values. This will allow the user to determine whether the javacpp default python path is used at all, and if so + * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external + * python distribution such as anaconda. + * + * 3. A main interpreter: This is the default interpreter to be used with the main thread. + * We may initialize one or more relative to the thread invoking the python code. + * + * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of + * native libraries needed by numpy are allowed to load in the proper order. If we don't do this, + * it causes a variety of issues with running numpy. + * + * 5. Various python scripts pre defined on the classpath included right with the java code. + * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior + * in order for us to manipulate values within the python memory, as well as pulling them out of memory + * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)} + * as an example. + * + * For more information on how this works, please take a look at the {@link #init()} + * method. + * + * Generally, a user defining a python script for use by the python executioner + * will have a set of defined target input values and output values. + * These values should not be present when actually running the script, but just referenced. 
+ * In order to test your python script for execution outside the engine, + * we recommend commenting out a few default values as dummy input values. + * This will allow an end user to test their script before trying to use the server. + * + * In order to get output values out of a python script, all a user has to do + * is define the output variables they want being used in the final output in the actual pipeline. + * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name + * and based on the configured {@link PythonVariables} passed as outputs + * to one of the execution methods, we can pull the values out automatically. + * + * For input definitions, it is similar. You just define the values you want used in + * {@link PythonVariables} and we will automatically generate code for defining those values + * as desired for running. This allows the user to customize values dynamically + * at runtime but reference them by name in a python script. + * + * @author Fariz Rahman + * @author Adam Gibson + */
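// Illustrative sketch (not part of this patch): the input/output model described in the
// javadoc above, using only calls that appear elsewhere in this diff (PythonVariables
// addInt/addStr, PythonExecutioner.exec(code, inputs, outputs), getIntValue). The script
// text and the variable names "x", "y" and "z" are invented for the example.

import org.datavec.python.PythonExecutioner;
import org.datavec.python.PythonVariables;

class PythonExecutionerUsageSketch {
    static long addLength() throws Exception {
        PythonVariables pyInputs = new PythonVariables();
        pyInputs.addInt("x", 5);            // referenced by name inside the script
        pyInputs.addStr("y", "hello");

        PythonVariables pyOutputs = new PythonVariables();
        pyOutputs.addInt("z");              // declared here, assigned by the script

        // exec(...) generates python assignments for the inputs, runs the code,
        // then reads the declared outputs back by name
        PythonExecutioner.exec("z = x + len(y)", pyInputs, pyOutputs);
        return pyOutputs.getIntValue("z");  // 10
    }
}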
@Slf4j public class PythonExecutioner { - private static PyObject module; - private static PyObject globals; - private static JSONParser parser = new JSONParser(); - private static Map gilStates = new HashMap<>(); + + private final static String fileVarName = "_f" + Nd4j.getRandom().nextInt(); + private static boolean init; + public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path"; + public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append"; + public final static String DEFAULT_APPEND_TYPE = "before"; + private static Map interpreters = new java.util.concurrent.ConcurrentHashMap<>(); + private static PyThreadState currentThreadState; + private static PyThreadState mainThreadState; + public final static String ALL_VARIABLES_KEY = "allVariables"; + public final static String MAIN_INTERPRETER_NAME = "main"; + private static String clearVarsCode; + + private static String currentInterpreter = MAIN_INTERPRETER_NAME; + + /** + * One of a few desired values + * for how we should handle + * using javacpp's python path. + * BEFORE: Prepend the python path alongside a defined one + * AFTER: Append the javacpp python path alongside the defined one + * NONE: Don't use javacpp's python path at all + */ + public enum JavaCppPathType { + BEFORE,AFTER,NONE + } + + /** + * Set the python path. + * Generally you can just use the PYTHONPATH environment variable, + * but if you need to set it from code, this can work as well. 
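// Illustrative sketch (not part of this patch): configuring the two system properties
// introduced above before PythonExecutioner is first loaded (its static initializer calls
// setPythonPath()). The property keys and the BEFORE/AFTER/NONE values come from this diff;
// the site-packages path is a made-up example for an external (e.g. anaconda) install.

class PythonPathConfigSketch {
    static void configureBeforeFirstUse() {
        // point the embedded interpreter at an external python distribution
        System.setProperty("org.datavec.python.path",
                "/opt/anaconda3/lib/python3.7/site-packages");
        // controls whether javacpp's bundled python path is prepended ("before"),
        // appended ("after") or ignored ("none"); the default is "before"
        System.setProperty("org.datavec.python.javacpp.path.append", "none");
    }
}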
+ */ + public static synchronized void setPythonPath() { + if(!init) { + try { + String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); + if(path == null) { + log.info("Setting python default path"); + File[] packages = numpy.cachePackages(); + Py_SetPath(packages); + } + else { + log.info("Setting python path " + path); + StringBuffer sb = new StringBuffer(); + File[] packages = numpy.cachePackages(); + + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE,DEFAULT_APPEND_TYPE).toUpperCase()); + switch(pathAppendValue) { + case BEFORE: + for(File cacheDir : packages) { + sb.append(cacheDir); + sb.append(java.io.File.pathSeparator); + } + + sb.append(path); + + log.info("Prepending javacpp python path " + sb.toString()); + break; + case AFTER: + sb.append(path); + + for(File cacheDir : packages) { + sb.append(cacheDir); + sb.append(java.io.File.pathSeparator); + } + + log.info("Appending javacpp python path " + sb.toString()); + break; + case NONE: + log.info("Not appending javacpp path"); + sb.append(path); + break; + } + + //prepend the javacpp packages + log.info("Final python path " + sb.toString()); + + Py_SetPath(sb.toString()); + } + } catch (IOException e) { + log.error("Failed to set python path.", e); + } + } + else { + throw new IllegalStateException("Unable to reset python path. Already initialized."); + } + } + + /** + * Initialize the name space and the python execution + * Calling this method more than once will be a no op + */ + public static synchronized void init() { + if(init) { + return; + } + + try(InputStream is = new org.nd4j.linalg.io.ClassPathResource("pythonexec/clear_vars.py").getInputStream()) { + clearVarsCode = IOUtils.toString(new java.io.InputStreamReader(is)); + } catch (java.io.IOException e) { + throw new IllegalStateException("Unable to read pythonexec/clear_vars.py"); + } + + log.info("CPython: PyEval_InitThreads()"); + PyEval_InitThreads(); + log.info("CPython: Py_InitializeEx()"); + Py_InitializeEx(0); + log.info("CPython: PyGILState_Release()"); + init = true; + interpreters.put(MAIN_INTERPRETER_NAME, PyThreadState_Get()); + numpy._import_array(); + applyPatches(); + } + + + /** + * Run {@link #resetInterpreter(String)} + * on all interpreters. + */ + public static void resetAllInterpreters() { + for(String interpreter : interpreters.keySet()) { + resetInterpreter(interpreter); + } + } + + /** + * Reset the main interpreter. + * For more information see {@link #resetInterpreter(String)} + */ + public static void resetMainInterpreter() { + resetInterpreter(MAIN_INTERPRETER_NAME); + } + + /** + * Reset the interpreter with the given name. + * Runs pythonexec/clear_vars.py + * For more information see: + * https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script + * @param interpreterName the interpreter name to + * reset + */ + public static synchronized void resetInterpreter(String interpreterName) { + Preconditions.checkState(hasInterpreter(interpreterName)); + log.info("Resetting interpreter " + interpreterName); + String oldInterpreter = currentInterpreter; + setInterpreter(interpreterName); + exec("pass"); + //exec(interpreterName); // ?? + setInterpreter(oldInterpreter); + } + + /** + * Clear the non main intrepreters. 
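// Brief illustrative sketch (not part of this patch): the reset helpers defined above,
// intended for clearing interpreter state between unrelated scripts. Method names are taken
// from this diff; the script content is invented.

class InterpreterResetSketch {
    static void run() {
        org.datavec.python.PythonExecutioner.exec("tmp = 42");        // leaves 'tmp' in the current interpreter's globals
        org.datavec.python.PythonExecutioner.resetMainInterpreter();  // reset the main interpreter's state
        org.datavec.python.PythonExecutioner.resetAllInterpreters();  // or reset every registered interpreter
    }
}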
+ */ + public static void clearNonMainInterpreters() { + for(String key : interpreters.keySet()) { + if(!key.equals(MAIN_INTERPRETER_NAME)) { + deleteInterpreter(key); + } + } + } + + public static PythonVariables defaultPythonVariableOutput() { + PythonVariables ret = new PythonVariables(); + ret.add(ALL_VARIABLES_KEY, PythonVariables.Type.DICT); + return ret; + } + + /** + * Return the python path being used. + * @return a string specifying the python path in use + */ + public static String getPythonPath() { + return new BytePointer(Py_GetPath()).getString(); + } + static { + setPythonPath(); init(); } - public static void init(){ - log.info("CPython: Py_InitializeEx()"); - Py_InitializeEx(1); - log.info("CPython: PyEval_InitThreads()"); - PyEval_InitThreads(); - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - log.info("CPython: PyThreadState_Get()"); + + /* ---------sub-interpreter and gil management-----------*/ + public static void setInterpreter(String interpreterName) { + if (!hasInterpreter(interpreterName)){ + PyThreadState main = PyThreadState_Get(); + PyThreadState ts = Py_NewInterpreter(); + + interpreters.put(interpreterName, ts); + PyThreadState_Swap(main); + } + + currentInterpreter = interpreterName; + } + + /** + * Returns the current interpreter. + * @return + */ + public static String getInterpreter() { + return currentInterpreter; + } + + + public static boolean hasInterpreter(String interpreterName){ + return interpreters.containsKey(interpreterName); + } + + public static void deleteInterpreter(String interpreterName) { + if (interpreterName.equals("main")){ + throw new IllegalArgumentException("Can not delete main interpreter"); + } + + Py_EndInterpreter(interpreters.remove(interpreterName)); + } + + private static synchronized void acquireGIL() { + log.info("acquireGIL()"); + log.info("CPython: PyEval_SaveThread()"); + mainThreadState = PyEval_SaveThread(); + log.info("CPython: PyThreadState_New()"); + currentThreadState = PyThreadState_New(interpreters.get(currentInterpreter).interp()); + log.info("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(currentThreadState); + log.info("CPython: PyThreadState_Swap()"); + PyThreadState_Swap(currentThreadState); + + } + + private static synchronized void releaseGIL() { + log.info("CPython: PyEval_SaveThread()"); PyEval_SaveThread(); + log.info("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(mainThreadState); } - public static void free(){ - Py_Finalize(); + /* -------------------*/ + /** + * Print the python version to standard out. 
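// Illustrative sketch (not part of this patch): the sub-interpreter management methods
// defined above. The interpreter name "worker" is arbitrary. Note that the exec*() methods
// further down reject code containing "import numpy" on any interpreter other than "main".

import org.datavec.python.PythonExecutioner;

class SubInterpreterSketch {
    static void run() {
        PythonExecutioner.setInterpreter("worker");               // created on first use
        PythonExecutioner.exec("greeting = 'hi'");                // runs inside "worker"
        PythonExecutioner.setInterpreter(PythonExecutioner.MAIN_INTERPRETER_NAME);
        PythonExecutioner.deleteInterpreter("worker");            // "main" itself cannot be deleted
    }
}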
+ */ + public static void printPythonVersion() { + exec("import sys; print(sys.version); sys.stdout.flush()"); } - private static String inputCode(PythonVariables pyInputs)throws Exception{ - String inputCode = "loc={};"; + + + private static String inputCode(PythonVariables pyInputs)throws Exception { + String inputCode = ""; if (pyInputs == null){ return inputCode; } + Map strInputs = pyInputs.getStrVariables(); Map intInputs = pyInputs.getIntVariables(); Map floatInputs = pyInputs.getFloatVariables(); - Map ndInputs = pyInputs.getNDArrayVariables(); + Map ndInputs = pyInputs.getNdVars(); Map listInputs = pyInputs.getListVariables(); Map fileInputs = pyInputs.getFileVariables(); + Map> dictInputs = pyInputs.getDictVariables(); - String[] VarNames; + String[] varNames; - VarNames = strInputs.keySet().toArray(new String[strInputs.size()]); - for(Object varName: VarNames){ + varNames = strInputs.keySet().toArray(new String[strInputs.size()]); + for(String varName: varNames) { + Preconditions.checkNotNull(varName,"Var name is null!"); + Preconditions.checkState(!varName.isEmpty(),"Var name can not be empty!"); String varValue = strInputs.get(varName); - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + //inputCode += varName + "= {}\n"; + if(varValue != null) + inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; + else { + inputCode += varName + " = ''\n"; + } } - VarNames = intInputs.keySet().toArray(new String[intInputs.size()]); - for(String varName: VarNames){ + varNames = intInputs.keySet().toArray(new String[intInputs.size()]); + for(String varName: varNames) { Long varValue = intInputs.get(varName); - inputCode += varName + " = " + varValue.toString() + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = " + varValue.toString() + "\n"; + else { + inputCode += varName + " = 0\n"; + } } - VarNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); - for(String varName: VarNames){ + varNames = dictInputs.keySet().toArray(new String[dictInputs.size()]); + for(String varName: varNames) { + Map varValue = dictInputs.get(varName); + if(varValue != null) { + throw new IllegalArgumentException("Unable to generate input code for dictionaries."); + } + else { + inputCode += varName + " = {}\n"; + } + } + + varNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); + for(String varName: varNames){ Double varValue = floatInputs.get(varName); - inputCode += varName + " = " + varValue.toString() + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = " + varValue.toString() + "\n"; + else { + inputCode += varName + " = 0.0\n"; + } } - VarNames = listInputs.keySet().toArray(new String[listInputs.size()]); - for (String varName: VarNames){ + varNames = listInputs.keySet().toArray(new String[listInputs.size()]); + for (String varName: varNames) { Object[] varValue = listInputs.get(varName); - String listStr = jArrayToPyString(varValue); - inputCode += varName + " = " + listStr + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) { + String listStr = jArrayToPyString(varValue); + inputCode += varName + " = " + listStr + "\n"; + } + else { + inputCode += varName + " = []\n"; + } + + } - VarNames = fileInputs.keySet().toArray(new String[fileInputs.size()]); - for(Object varName: VarNames){ + varNames = fileInputs.keySet().toArray(new 
String[fileInputs.size()]); + for(String varName: varNames) { String varValue = fileInputs.get(varName); - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; + else { + inputCode += varName + " = ''\n"; + } } - if (ndInputs.size()> 0){ - inputCode += "import ctypes; import numpy as np;"; - VarNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); + if (!ndInputs.isEmpty()) { + inputCode += "import ctypes\n\nimport sys\nimport numpy as np\n"; + varNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); - String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape);"; + String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape)\n"; inputCode += converter; - for(String varName: VarNames){ + for(String varName: varNames) { NumpyArray npArr = ndInputs.get(varName); + if(npArr == null) + continue; + npArr = npArr.copy(); String shapeStr = "("; for (long d: npArr.getShape()){ - shapeStr += String.valueOf(d) + ","; + shapeStr += d + ","; } shapeStr += ")"; String code; String ctype; - if (npArr.getDtype() == DataType.FLOAT){ + if (npArr.getDtype() == DataType.FLOAT) { ctype = "ctypes.c_float"; } - else if (npArr.getDtype() == DataType.DOUBLE){ + else if (npArr.getDtype() == DataType.DOUBLE) { ctype = "ctypes.c_double"; } - else if (npArr.getDtype() == DataType.SHORT){ + else if (npArr.getDtype() == DataType.SHORT) { ctype = "ctypes.c_int16"; } - else if (npArr.getDtype() == DataType.INT){ + else if (npArr.getDtype() == DataType.INT) { ctype = "ctypes.c_int32"; } else if (npArr.getDtype() == DataType.LONG){ @@ -152,10 +524,9 @@ public class PythonExecutioner { throw new Exception("Unsupported data type: " + npArr.getDtype().toString() + "."); } - code = "__arr_converter(" + String.valueOf(npArr.getAddress()) + "," + shapeStr + "," + ctype + ")"; - code = varName + "=" + code + "\n"; + code = "__arr_converter(" + npArr.getAddress() + "," + shapeStr + "," + ctype + ")"; + code = varName + "=" + code + "\n"; inputCode += code; - inputCode += "loc['" + varName + "']=" + varName + "\n"; } } @@ -163,49 +534,62 @@ public class PythonExecutioner { } - private static void _readOutputs(PythonVariables pyOutputs){ - String json = read(getTempFile()); + private static synchronized void _readOutputs(PythonVariables pyOutputs) throws IOException { File f = new File(getTempFile()); + Preconditions.checkState(f.exists(),"File " + f.getAbsolutePath() + " failed to get written for reading outputs!"); + String json = FileUtils.readFileToString(f, Charset.defaultCharset()); + log.info("Executioner output: "); + log.info(json); f.delete(); - JSONParser p = new JSONParser(); - try{ - JSONObject jobj = (JSONObject) p.parse(json); - for (String varName: pyOutputs.getVariables()){ + + if(json.isEmpty()) { + log.warn("No json found fore reading outputs. 
Returning."); + return; + } + + try { + JSONObject jobj = new JSONObject(json); + for (String varName: pyOutputs.getVariables()) { PythonVariables.Type type = pyOutputs.getType(varName); - if (type == PythonVariables.Type.NDARRAY){ + if (type == PythonVariables.Type.NDARRAY) { JSONObject varValue = (JSONObject)jobj.get(varName); - long address = (Long)varValue.get("address"); - JSONArray shapeJson = (JSONArray)varValue.get("shape"); - JSONArray stridesJson = (JSONArray)varValue.get("strides"); + long address = (Long) varValue.getLong("address"); + JSONArray shapeJson = (JSONArray) varValue.get("shape"); + JSONArray stridesJson = (JSONArray) varValue.get("strides"); long[] shape = jsonArrayToLongArray(shapeJson); long[] strides = jsonArrayToLongArray(stridesJson); String dtypeName = (String)varValue.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; } - else if (dtypeName.equals("float32")){ + else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; } - else if (dtypeName.equals("int16")){ + else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; } - else if (dtypeName.equals("int32")){ + else if (dtypeName.equals("int32")) { dtype = DataType.INT; } - else if (dtypeName.equals("int64")){ + else if (dtypeName.equals("int64")) { dtype = DataType.LONG; } else{ throw new Exception("Unsupported array type " + dtypeName + "."); } + pyOutputs.setValue(varName, new NumpyArray(address, shape, strides, dtype, true)); - } - else if (type == PythonVariables.Type.LIST){ - JSONArray varValue = (JSONArray)jobj.get(varName); - pyOutputs.setValue(varName, varValue.toArray()); + else if (type == PythonVariables.Type.LIST) { + JSONArray varValue = (JSONArray) jobj.get(varName); + pyOutputs.setValue(varName, varValue); + } + else if (type == PythonVariables.Type.DICT) { + Map map = toMap((JSONObject) jobj.get(varName)); + pyOutputs.setValue(varName, map); + } else{ pyOutputs.setValue(varName, jobj.get(varName)); @@ -217,266 +601,422 @@ public class PythonExecutioner { } } - private static void acquireGIL(){ - log.info("---_enterSubInterpreter()---"); - if (PyGILState_Check() != 1){ - gilStates.put(Thread.currentThread().getId(), PyGILState_Ensure()); - log.info("GIL ensured"); + + + + private static synchronized void _exec(String code) { + log.info(code); + log.info("CPython: PyRun_SimpleStringFlag()"); + + int result = PyRun_SimpleStringFlags(code, null); + if (result != 0) { + log.info("CPython: PyErr_Print"); + PyErr_Print(); + throw new RuntimeException("exec failed"); } } - private static void releaseGIL(){ - if (PyGILState_Check() == 1){ - log.info("Releasing gil..."); - PyGILState_Release(gilStates.get(Thread.currentThread().getId())); - log.info("Gil released."); - } - + private static synchronized void _exec_wrapped(String code) { + _exec(getWrappedCode(code)); } /** * Executes python code. Also manages python thread state. - * @param code + * @param code the code to run */ - public static void exec(String code){ - code = getFunctionalCode("__f_" + Thread.currentThread().getId(), code); + public static void exec(String code) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) {// FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. 
See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + acquireGIL(); - log.info("CPython: PyRun_SimpleStringFlag()"); - log.info(code); - int result = PyRun_SimpleStringFlags(code, null); - if (result != 0){ - PyErr_Print(); - throw new RuntimeException("exec failed"); + _exec(code); + log.info("Exec done"); + releaseGIL(); + } + + private static boolean _hasGlobalVariable(String varName){ + PyObject mainModule = PyImport_AddModule("__main__"); + PyObject var = PyObject_GetAttrString(mainModule, varName); + boolean hasVar = var != null; + Py_DecRef(var); + return hasVar; + } + + /** + * Executes python code and looks for methods setup() and run() + * If both setup() and run() are found, both are executed for the first + * time and for subsequent calls only run() is executed. + */ + public static void execWithSetupAndRun(String code) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + } + log.debug("Calling run()..."); + _exec("run()"); } log.info("Exec done"); releaseGIL(); } - public static void exec(String code, PythonVariables pyOutputs){ - exec(code + '\n' + outputCode(pyOutputs)); - _readOutputs(pyOutputs); + /** + * Executes python code and looks for methods setup() and run() + * If both setup() and run() are found, both are executed for the first + * time and for subsequent calls only run() is executed. + */ + public static void execWithSetupAndRun(String code, PythonVariables pyOutputs) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. 
See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + } + log.debug("Calling run()..."); + _exec("__out = run();for (k,v) in __out.items(): globals()[k]=v"); + } + log.info("Exec done"); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); } - public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{ + /** + * Run the given code with the given python outputs + * @param code the code to run + * @param pyOutputs the outputs to run + */ + public static void exec(String code, PythonVariables pyOutputs) { + + exec(code + '\n' + outputCode(pyOutputs)); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); + } + + + /** + * Execute the given python code with the given + * {@link PythonVariables} as inputs and outputs + * @param code the code to run + * @param pyInputs the inputs to the code + * @param pyOutputs the outputs to the code + * @throws Exception + */ + public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { String inputCode = inputCode(pyInputs); exec(inputCode + code, pyOutputs); } - - public static PythonVariables exec(PythonTransform transform) throws Exception{ - if (transform.getInputs() != null && transform.getInputs().getVariables().length > 0){ - throw new Exception("Required inputs not provided."); + /** + * Execute the given python code + * with the {@link PythonVariables} + * inputs and outputs for storing the values + * specified by the user and needed by the user + * as output + * @param code the python code to execute + * @param pyInputs the python variables input in to the python script + * @param pyOutputs the python variables output returned by the python script + * @throws Exception + */ + public static void execWithSetupAndRun(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { + String inputCode = inputCode(pyInputs); + code = inputCode +code; + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. 
See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); } - exec(transform.getCode(), null, transform.getOutputs()); - return transform.getOutputs(); + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + releaseGIL(); // required + acquireGIL(); + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + }else{ + log.debug("setup() already called once."); + } + log.debug("Calling run()..."); + releaseGIL(); // required + acquireGIL(); + _exec("import inspect\n"+ + "__out = run(**{k:globals()[k]for k in inspect.getfullargspec(run).args})\n"+ + "globals().update(__out)"); + } + releaseGIL(); // required + acquireGIL(); + _exec(outputCode(pyOutputs)); + log.info("Exec done"); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); } - public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception{ + + + private static String interpreterNameFromTransform(PythonTransform transform){ + return transform.getName().replace("-", "_"); + } + + + /** + * Run a {@link PythonTransform} with the given inputs + * @param transform the transform to run + * @param inputs the inputs to the transform + * @return the output variables + * @throws Exception + */ + public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception { + String name = interpreterNameFromTransform(transform); + setInterpreter(name); + Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); exec(transform.getCode(), inputs, transform.getOutputs()); return transform.getOutputs(); } - - - public static String evalSTRING(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject bytes = PyUnicode_AsEncodedString(xObj, "UTF-8", "strict"); - BytePointer bp = PyBytes_AsString(bytes); - String ret = bp.getString(); - Py_DecRef(xObj); - Py_DecRef(bytes); - return ret; + public static PythonVariables execWithSetupAndRun(PythonTransform transform, PythonVariables inputs)throws Exception { + String name = interpreterNameFromTransform(transform); + setInterpreter(name); + Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); + execWithSetupAndRun(transform.getCode(), inputs, transform.getOutputs()); + return transform.getOutputs(); } - public static long evalINTEGER(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - long ret = PyLong_AsLongLong(xObj); - return ret; + + /** + * Run the code and return the outputs + * @param code the code to run + * @return all python variables + */ + public static PythonVariables execAndReturnAllVariables(String code) { + exec(code + '\n' + outputCodeForAllVariables()); + PythonVariables allVars = new PythonVariables(); + allVars.addDict(ALL_VARIABLES_KEY); + try { + _readOutputs(allVars); + }catch (IOException e) { + log.error("Failed to read outputs", e); + } + + return expandInnerDict(allVars, ALL_VARIABLES_KEY); + } + 
public static PythonVariables execWithSetupRunAndReturnAllVariables(String code) { + execWithSetupAndRun(code + '\n' + outputCodeForAllVariables()); + PythonVariables allVars = new PythonVariables(); + allVars.addDict(ALL_VARIABLES_KEY); + try { + _readOutputs(allVars); + }catch (IOException e) { + log.error("Failed to read outputs", e); + } + + return expandInnerDict(allVars, ALL_VARIABLES_KEY); } - public static double evalFLOAT(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - double ret = PyFloat_AsDouble(xObj); - return ret; + /** + * + * @param code code string to run + * @param pyInputs python input variables + * @return all python variables + * @throws Exception throws when there's an issue while execution of python code + */ + public static PythonVariables execAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { + String inputCode = inputCode(pyInputs); + return execAndReturnAllVariables(inputCode + code); + } + public static PythonVariables execWithSetupRunAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { + String inputCode = inputCode(pyInputs); + return execWithSetupRunAndReturnAllVariables(inputCode + code); } - public static Object[] evalLIST(String varName) throws Exception{ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject strObj = PyObject_Str(xObj); - PyObject bytes = PyUnicode_AsEncodedString(strObj, "UTF-8", "strict"); - BytePointer bp = PyBytes_AsString(bytes); - String listStr = bp.getString(); - Py_DecRef(xObj); - Py_DecRef(bytes); - JSONArray jsonArray = (JSONArray)parser.parse(listStr.replace("\'", "\"")); - return jsonArray.toArray(); + + /** + * Evaluate a string based on the + * current variable name. + * This variable named needs to be present + * or defined earlier in python code + * in order to pull out the values. 
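// Illustrative sketch (not part of this patch): the setup()/run() convention consumed by the
// execWithSetupAndRun*/execWithSetupRunAndReturnAllVariables methods above. setup() runs only
// once per interpreter (guarded by the __setup_done__ flag), run() on every call; globals left
// behind by the script are what gets read back. The python snippet and names are invented.

import org.datavec.python.PythonExecutioner;
import org.datavec.python.PythonVariables;

class SetupAndRunSketch {
    static PythonVariables run() {
        String script =
                "def setup():\n" +
                "    global model\n" +
                "    model = 3          # one-time, expensive initialisation goes here\n" +
                "\n" +
                "def run():\n" +
                "    global result\n" +
                "    result = model * 2\n";

        // every global, including 'result', is pulled back into the returned PythonVariables
        return PythonExecutioner.execWithSetupRunAndReturnAllVariables(script);
    }
}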
+ * + * @param varName the variable name to evaluate + * @return the evaluated value + */ + public static String evalString(String varName) { + PythonVariables vars = new PythonVariables(); + vars.addStr(varName); + exec("print('')", vars); + return vars.getStrValue(varName); } - public static NumpyArray evalNDARRAY(String varName) throws Exception{ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject arrayInterface = PyObject_GetAttrString(xObj, "__array_interface__"); - PyObject data = PyDict_GetItemString(arrayInterface, "data"); - PyObject zero = PyLong_FromLong(0); - PyObject addressObj = PyObject_GetItem(data, zero); - long address = PyLong_AsLongLong(addressObj); - PyObject shapeObj = PyObject_GetAttrString(xObj, "shape"); - int ndim = (int)PyObject_Size(shapeObj); - PyObject iObj; - long shape[] = new long[ndim]; - for (int i=0; i 0) + outputCode = outputCode.substring(0, outputCode.length() - 1); + outputCode += "})"; + outputCode += "\nwith open('" + getTempFile() + "', 'w') as " + fileVarName + ":" + fileVarName + ".write(" + outputVarName() + ")"; + + return outputCode; } - private static String read(String path){ - try{ - File file = new File(path); - FileInputStream fis = new FileInputStream(file); - byte[] data = new byte[(int) file.length()]; - fis.read(data); - fis.close(); - String str = new String(data, "UTF-8"); - return str; - } - catch (Exception e){ - return ""; - } - } - private static String jArrayToPyString(Object[] array){ + private static String jArrayToPyString(Object[] array) { String str = "["; - for (int i=0; i < array.length; i++){ + for (int i = 0; i < array.length; i++){ Object obj = array[i]; if (obj instanceof Object[]){ str += jArrayToPyString((Object[])obj); @@ -496,32 +1036,109 @@ public class PythonExecutioner { return str; } - private static String escapeStr(String str){ + private static String escapeStr(String str) { + if(str == null) + return null; str = str.replace("\\", "\\\\"); str = str.replace("\"\"\"", "\\\"\\\"\\\""); return str; } - private static String getFunctionalCode(String functionName, String code){ - String out = String.format("def %s():\n", functionName); - for(String line: code.split(Pattern.quote("\n"))){ - out += " " + line + "\n"; + private static String getWrappedCode(String code) { + try(InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { + String base = IOUtils.toString(is, Charset.defaultCharset()); + StringBuffer indentedCode = new StringBuffer(); + for(String split : code.split("\n")) { + indentedCode.append(" " + split + "\n"); + + } + + String out = base.replace(" pass",indentedCode); + return out; + } catch (IOException e) { + throw new IllegalStateException("Unable to read python code!",e); } - return out + "\n\n" + functionName + "()\n"; + } - private static String getTempFile(){ - String ret = "temp_" + Thread.currentThread().getId() + ".json"; + + private static String getTempFile() { + String ret = "temp_" + Thread.currentThread().getId() + "_" + currentInterpreter + ".json"; log.info(ret); return ret; } - private static long[] jsonArrayToLongArray(JSONArray jsonArray){ - long[] longs = new long[jsonArray.size()]; - for (int i=0; i _getPatches() { + exec("import numpy as np"); + exec( "__overrides_path = np.core.overrides.__file__"); + exec("__random_path = np.random.__file__"); + + List patches 
= new ArrayList<>(); + + patches.add(new String[]{ + "pythonexec/patch0.py", + evalString("__overrides_path") + }); + patches.add(new String[]{ + "pythonexec/patch1.py", + evalString("__random_path") + }); + + return patches; + } + + private static void _applyPatch(String src, String dest){ + try(InputStream is = new ClassPathResource(src).getInputStream()) { + String patch = IOUtils.toString(is, Charset.defaultCharset()); + FileUtils.write(new File(dest), patch, "utf-8"); + } + catch(IOException e){ + throw new RuntimeException("Error reading resource."); + } + } + + private static boolean _checkPatchApplied(String dest) { + try { + return FileUtils.readFileToString(new File(dest), "utf-8").startsWith("#patch"); + } catch (IOException e) { + throw new RuntimeException("Error patching numpy"); + + } + } + + private static void applyPatches() { + for (String[] patch : _getPatches()){ + if (_checkPatchApplied(patch[1])){ + log.info("Patch already applied for " + patch[1]); + } + else{ + _applyPatch(patch[0], patch[1]); + log.info("Applied patch for " + patch[1]); + } + } + for (String[] patch: _getPatches()){ + if (!_checkPatchApplied(patch[1])){ + throw new RuntimeException("Error patching numpy"); + } + } + } +} \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index e3b3fb2bf..8f2460035 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -16,16 +16,29 @@ package org.datavec.python; +import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import org.apache.commons.io.IOUtils; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; +import org.nd4j.base.Preconditions; +import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; +import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.shade.jackson.core.JsonProcessingException; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.UUID; +import static org.datavec.python.PythonUtils.schemaToPythonVariables; + /** * Row-wise Transform that applies arbitrary python code on each row * @@ -34,31 +47,87 @@ import java.util.UUID; @NoArgsConstructor @Data -public class PythonTransform implements Transform{ +public class PythonTransform implements Transform { + private String code; - private PythonVariables pyInputs; - private PythonVariables pyOutputs; - private String name; + private PythonVariables inputs; + private PythonVariables outputs; + private String name = UUID.randomUUID().toString(); private Schema inputSchema; private Schema outputSchema; + private String outputDict; + private boolean returnAllVariables; + private boolean setupAndRun = false; - public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{ + @Builder + public PythonTransform(String code, + PythonVariables inputs, + PythonVariables outputs, + String name, + Schema inputSchema, + Schema outputSchema, + String outputDict, + boolean returnAllInputs, + boolean setupAndRun) { + Preconditions.checkNotNull(code,"No code found to run!"); this.code = code; - this.pyInputs = pyInputs; - 
this.pyOutputs = pyOutputs; - this.name = UUID.randomUUID().toString(); + this.returnAllVariables = returnAllInputs; + this.setupAndRun = setupAndRun; + if(inputs != null) + this.inputs = inputs; + if(outputs != null) + this.outputs = outputs; + + if(name != null) + this.name = name; + if (outputDict != null) { + this.outputDict = outputDict; + this.outputs = new PythonVariables(); + this.outputs.addDict(outputDict); + + String helpers; + try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) { + helpers = IOUtils.toString(is, Charset.defaultCharset()); + + }catch (IOException e){ + throw new RuntimeException("Error reading python code"); + } + this.code += "\n\n" + helpers; + this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")"; + } + + try { + if(inputSchema != null) { + this.inputSchema = inputSchema; + if(inputs == null || inputs.isEmpty()) { + this.inputs = schemaToPythonVariables(inputSchema); + } + } + + if(outputSchema != null) { + this.outputSchema = outputSchema; + if(outputs == null || outputs.isEmpty()) { + this.outputs = schemaToPythonVariables(outputSchema); + } + } + }catch(Exception e) { + throw new IllegalStateException(e); + } + } + @Override - public void setInputSchema(Schema inputSchema){ + public void setInputSchema(Schema inputSchema) { + Preconditions.checkNotNull(inputSchema,"No input schema found!"); this.inputSchema = inputSchema; try{ - pyInputs = schemaToPythonVariables(inputSchema); + inputs = schemaToPythonVariables(inputSchema); }catch (Exception e){ throw new RuntimeException(e); } - if (outputSchema == null){ + if (outputSchema == null && outputDict == null){ outputSchema = inputSchema; } @@ -88,12 +157,42 @@ public class PythonTransform implements Transform{ throw new UnsupportedOperationException("Not yet implemented"); } + + + @Override - public List map(List writables){ + public List map(List writables) { PythonVariables pyInputs = getPyInputsFromWritables(writables); + Preconditions.checkNotNull(pyInputs,"Inputs must not be null!"); + + try{ - PythonExecutioner.exec(code, pyInputs, pyOutputs); - return getWritablesFromPyOutputs(pyOutputs); + if (returnAllVariables) { + if (setupAndRun){ + return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs)); + } + return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs)); + } + + if (outputDict != null) { + if (setupAndRun) { + PythonExecutioner.execWithSetupAndRun(this, pyInputs); + }else{ + PythonExecutioner.exec(this, pyInputs); + } + PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict); + return getWritablesFromPyOutputs(out); + } + else { + if (setupAndRun) { + PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs); + }else{ + PythonExecutioner.exec(code, pyInputs, outputs); + } + + return getWritablesFromPyOutputs(outputs); + } + } catch (Exception e){ throw new RuntimeException(e); @@ -102,7 +201,7 @@ public class PythonTransform implements Transform{ @Override public String[] outputColumnNames(){ - return pyOutputs.getVariables(); + return outputs.getVariables(); } @Override @@ -111,7 +210,7 @@ public class PythonTransform implements Transform{ } @Override public String[] columnNames(){ - return pyOutputs.getVariables(); + return outputs.getVariables(); } @Override @@ -124,14 +223,13 @@ public class PythonTransform implements Transform{ } - private PythonVariables getPyInputsFromWritables(List writables){ - + private 
PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (String name: pyInputs.getVariables()){ + for (String name: inputs.getVariables()) { int colIdx = inputSchema.getIndexOfColumn(name); Writable w = writables.get(colIdx); - PythonVariables.Type pyType = pyInputs.getType(name); + PythonVariables.Type pyType = inputs.getType(name); switch (pyType){ case INT: if (w instanceof LongWritable){ @@ -143,7 +241,7 @@ public class PythonTransform implements Transform{ break; case FLOAT: - if (w instanceof DoubleWritable){ + if (w instanceof DoubleWritable) { ret.addFloat(name, ((DoubleWritable)w).get()); } else{ @@ -151,96 +249,99 @@ public class PythonTransform implements Transform{ } break; case STR: - ret.addStr(name, ((Text)w).toString()); + ret.addStr(name, w.toString()); break; case NDARRAY: ret.addNDArray(name,((NDArrayWritable)w).get()); break; + default: + throw new RuntimeException("Unsupported input type:" + pyType); } } return ret; } - private List getWritablesFromPyOutputs(PythonVariables pyOuts){ + private List getWritablesFromPyOutputs(PythonVariables pyOuts) { List out = new ArrayList<>(); - for (int i=0; i dictValue = pyOuts.getDictValue(name); + Map noNullValues = new java.util.HashMap<>(); + for(Map.Entry entry : dictValue.entrySet()) { + if(entry.getValue() != org.json.JSONObject.NULL) { + noNullValues.put(entry.getKey(), entry.getValue()); + } + } + + try { + out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues))); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!"); + } + break; + case LIST: + Object[] listValue = pyOuts.getListValue(name); + try { + out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue))); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); + } + break; + default: + throw new IllegalStateException("Unable to support type " + pyType.name()); } } return out; } - public PythonTransform(String code) throws Exception{ - this.code = code; - this.name = UUID.randomUUID().toString(); - } - private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{ - PythonVariables pyVars = new PythonVariables(); - int numCols = schema.numColumns(); - for (int i=0; i 0,"Input must have variables. Found none."); + for(Map.Entry entry : input.getVars().entrySet()) { + switch(entry.getValue()) { + case INT: + schemaBuilder.addColumnInteger(entry.getKey()); + break; + case STR: + schemaBuilder.addColumnString(entry.getKey()); + break; + case FLOAT: + schemaBuilder.addColumnFloat(entry.getKey()); + break; + case NDARRAY: + schemaBuilder.addColumnNDArray(entry.getKey(),null); + break; + case BOOL: + schemaBuilder.addColumn(new BooleanMetaData(entry.getKey())); + } + } + + return schemaBuilder.build(); + } + + /** + * Create a {@link Schema} from an input + * {@link PythonVariables} + * Types are mapped to types of the same name + * @param input the input schema + * @return the output python variables. 
+ */ + public static PythonVariables fromSchema(Schema input) { + PythonVariables ret = new PythonVariables(); + for(int i = 0; i < input.numColumns(); i++) { + String currColumnName = input.getName(i); + ColumnType columnType = input.getType(i); + switch(columnType) { + case NDArray: + ret.add(currColumnName, PythonVariables.Type.NDARRAY); + break; + case Boolean: + ret.add(currColumnName, PythonVariables.Type.BOOL); + break; + case Categorical: + case String: + ret.add(currColumnName, PythonVariables.Type.STR); + break; + case Double: + case Float: + ret.add(currColumnName, PythonVariables.Type.FLOAT); + break; + case Integer: + case Long: + ret.add(currColumnName, PythonVariables.Type.INT); + break; + case Bytes: + break; + case Time: + throw new UnsupportedOperationException("Unable to process dates with python yet."); + } + } + + return ret; + } + /** + * Convert a {@link Schema} + * to {@link PythonVariables} + * @param schema the input schema + * @return the output {@link PythonVariables} where each + * name in the map is associated with a column name in the schema. + * A proper type is also chosen based on the schema + * @throws Exception + */ + public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception { + PythonVariables pyVars = new PythonVariables(); + int numCols = schema.numColumns(); + for (int i = 0; i < numCols; i++) { + String colName = schema.getName(i); + ColumnType colType = schema.getType(i); + switch (colType){ + case Long: + case Integer: + pyVars.addInt(colName); + break; + case Double: + case Float: + pyVars.addFloat(colName); + break; + case String: + pyVars.addStr(colName); + break; + case NDArray: + pyVars.addNDArray(colName); + break; + default: + throw new Exception("Unsupported python input type: " + colType.toString()); + } + } + + return pyVars; + } + + + public static NumpyArray mapToNumpyArray(Map map){ + String dtypeName = (String)map.get("dtype"); + DataType dtype; + if (dtypeName.equals("float64")){ + dtype = DataType.DOUBLE; + } + else if (dtypeName.equals("float32")){ + dtype = DataType.FLOAT; + } + else if (dtypeName.equals("int16")){ + dtype = DataType.SHORT; + } + else if (dtypeName.equals("int32")){ + dtype = DataType.INT; + } + else if (dtypeName.equals("int64")){ + dtype = DataType.LONG; + } + else{ + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + List shapeList = (List)map.get("shape"); + long[] shape = new long[shapeList.size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = (Long)shapeList.get(i); + } + + List strideList = (List)map.get("shape"); + long[] stride = new long[strideList.size()]; + for (int i = 0; i < stride.length; i++) { + stride[i] = (Long)strideList.get(i); + } + long address = (Long)map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + return numpyArray; + } + + public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){ + Map dict = pyvars.getDictValue(key); + String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]); + PythonVariables pyvars2 = new PythonVariables(); + for (String subkey: keys){ + Object value = dict.get(subkey); + if (value instanceof Map){ + Map map = (Map)value; + if (map.containsKey("_is_numpy_array")){ + pyvars2.addNDArray(subkey, mapToNumpyArray(map)); + + } + else{ + pyvars2.addDict(subkey, (Map)value); + } + + } + else if (value instanceof List){ + pyvars2.addList(subkey, ((List) value).toArray()); + } + else if (value instanceof 
String){ + System.out.println((String)value); + pyvars2.addStr(subkey, (String) value); + } + else if (value instanceof Integer || value instanceof Long) { + Number number = (Number) value; + pyvars2.addInt(subkey, number.intValue()); + } + else if (value instanceof Float || value instanceof Double) { + Number number = (Number) value; + pyvars2.addFloat(subkey, number.doubleValue()); + } + else if (value instanceof NumpyArray){ + pyvars2.addNDArray(subkey, (NumpyArray)value); + } + else if (value == null){ + pyvars2.addStr(subkey, "None"); // FixMe + } + else{ + throw new RuntimeException("Unsupported type!" + value); + } + } + return pyvars2; + } + + public static long[] jsonArrayToLongArray(JSONArray jsonArray){ + long[] longs = new long[jsonArray.length()]; + for (int i=0; i toMap(JSONObject jsonobj) { + Map map = new HashMap<>(); + String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); + for (String key: keys){ + Object value = jsonobj.get(key); + if (value instanceof JSONArray) { + value = toList((JSONArray) value); + } else if (value instanceof JSONObject) { + JSONObject jsonobj2 = (JSONObject)value; + if (jsonobj2.has("_is_numpy_array")){ + value = jsonToNumpyArray(jsonobj2); + } + else{ + value = toMap(jsonobj2); + } + + } + + map.put(key, value); + } return map; + } + + + public static List toList(JSONArray array) { + List list = new ArrayList<>(); + for (int i = 0; i < array.length(); i++) { + Object value = array.get(i); + if (value instanceof JSONArray) { + value = toList((JSONArray) value); + } else if (value instanceof JSONObject) { + JSONObject jsonobj2 = (JSONObject) value; + if (jsonobj2.has("_is_numpy_array")) { + value = jsonToNumpyArray(jsonobj2); + } else { + value = toMap(jsonobj2); + } + } + list.add(value); + } + return list; + } + + + private static NumpyArray jsonToNumpyArray(JSONObject map){ + String dtypeName = (String)map.get("dtype"); + DataType dtype; + if (dtypeName.equals("float64")){ + dtype = DataType.DOUBLE; + } + else if (dtypeName.equals("float32")){ + dtype = DataType.FLOAT; + } + else if (dtypeName.equals("int16")){ + dtype = DataType.SHORT; + } + else if (dtypeName.equals("int32")){ + dtype = DataType.INT; + } + else if (dtypeName.equals("int64")){ + dtype = DataType.LONG; + } + else{ + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + List shapeList = (List)map.get("shape"); + long[] shape = new long[shapeList.size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = (Long)shapeList.get(i); + } + + List strideList = (List)map.get("shape"); + long[] stride = new long[strideList.size()]; + for (int i = 0; i < stride.length; i++) { + stride[i] = (Long)strideList.get(i); + } + long address = (Long)map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + return numpyArray; + } + + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java index fb05e7052..4d04f1d87 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java @@ -17,8 +17,8 @@ package org.datavec.python; import lombok.Data; -import org.json.simple.JSONArray; -import org.json.simple.JSONObject; +import org.json.JSONObject; +import org.json.JSONArray; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; @@ -31,8 +31,8 @@ import 
java.util.*; * @author Fariz Rahman */ -@Data -public class PythonVariables implements Serializable{ +@lombok.Data +public class PythonVariables implements java.io.Serializable { public enum Type{ BOOL, @@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{ FLOAT, NDARRAY, LIST, - FILE + FILE, + DICT } - private Map strVars = new HashMap(); - private Map intVars = new HashMap(); - private Map floatVars = new HashMap(); - private Map boolVars = new HashMap(); - private Map ndVars = new HashMap(); - private Map listVars = new HashMap(); - private Map fileVars = new HashMap(); - - private Map vars = new HashMap(); - - private Map maps = new HashMap(); + private java.util.Map strVariables = new java.util.LinkedHashMap<>(); + private java.util.Map intVariables = new java.util.LinkedHashMap<>(); + private java.util.Map floatVariables = new java.util.LinkedHashMap<>(); + private java.util.Map boolVariables = new java.util.LinkedHashMap<>(); + private java.util.Map ndVars = new java.util.LinkedHashMap<>(); + private java.util.Map listVariables = new java.util.LinkedHashMap<>(); + private java.util.Map fileVariables = new java.util.LinkedHashMap<>(); + private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); + private java.util.Map vars = new java.util.LinkedHashMap<>(); + private java.util.Map maps = new java.util.LinkedHashMap<>(); + /** + * Returns a copy of the variable + * schema in this array without the values + * @return an empty variables clone + * with no values + */ public PythonVariables copySchema(){ PythonVariables ret = new PythonVariables(); for (String varName: getVariables()){ @@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{ } return ret; } - public PythonVariables(){ - maps.put(Type.BOOL, boolVars); - maps.put(Type.STR, strVars); - maps.put(Type.INT, intVars); - maps.put(Type.FLOAT, floatVars); - maps.put(Type.NDARRAY, ndVars); - maps.put(Type.LIST, listVars); - maps.put(Type.FILE, fileVars); + /** + * + */ + public PythonVariables() { + maps.put(PythonVariables.Type.BOOL, boolVariables); + maps.put(PythonVariables.Type.STR, strVariables); + maps.put(PythonVariables.Type.INT, intVariables); + maps.put(PythonVariables.Type.FLOAT, floatVariables); + maps.put(PythonVariables.Type.NDARRAY, ndVars); + maps.put(PythonVariables.Type.LIST, listVariables); + maps.put(PythonVariables.Type.FILE, fileVariables); + maps.put(PythonVariables.Type.DICT, dictVariables); + + } + + + + /** + * + * @return true if there are no variables. 
+ */ + public boolean isEmpty() { + return getVariables().length < 1; } @@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{ break; case FILE: addFile(name); + break; + case DICT: + addDict(name); } } @@ -113,252 +137,463 @@ public class PythonVariables implements Serializable{ * @param name name of the variable * @param type type of the variable * @param value value of the variable (must be instance of expected type) - * @throws Exception */ - public void add (String name, Type type, Object value) throws Exception{ + public void add(String name, Type type, Object value) { add(name, type); setValue(name, value); } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ + public void addDict(String name) { + vars.put(name, PythonVariables.Type.DICT); + dictVariables.put(name,null); + } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addBool(String name){ - vars.put(name, Type.BOOL); - boolVars.put(name, null); + vars.put(name, PythonVariables.Type.BOOL); + boolVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addStr(String name){ - vars.put(name, Type.STR); - strVars.put(name, null); + vars.put(name, PythonVariables.Type.STR); + strVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addInt(String name){ - vars.put(name, Type.INT); - intVars.put(name, null); + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addFloat(String name){ - vars.put(name, Type.FLOAT); - floatVars.put(name, null); + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addNDArray(String name){ - vars.put(name, Type.NDARRAY); + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addList(String name){ - vars.put(name, Type.LIST); - listVars.put(name, null); + vars.put(name, PythonVariables.Type.LIST); + listVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addFile(String name){ - vars.put(name, Type.FILE); - fileVars.put(name, null); - } - public void addBool(String name, boolean value){ - vars.put(name, Type.BOOL); - boolVars.put(name, value); + vars.put(name, PythonVariables.Type.FILE); + fileVariables.put(name, null); } - public void addStr(String name, String value){ - vars.put(name, Type.STR); - strVars.put(name, value); + /** + * Add a boolean variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addBool(String name, boolean value) { + vars.put(name, PythonVariables.Type.BOOL); + boolVariables.put(name, value); } - public void addInt(String name, int value){ - vars.put(name, 
Type.INT); - intVars.put(name, (long)value); + /** + * Add a string variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addStr(String name, String value) { + vars.put(name, PythonVariables.Type.STR); + strVariables.put(name, value); } - public void addInt(String name, long value){ - vars.put(name, Type.INT); - intVars.put(name, value); + /** + * Add an int variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addInt(String name, int value) { + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, (long)value); } - public void addFloat(String name, double value){ - vars.put(name, Type.FLOAT); - floatVars.put(name, value); + /** + * Add a long variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addInt(String name, long value) { + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, value); } - public void addFloat(String name, float value){ - vars.put(name, Type.FLOAT); - floatVars.put(name, (double)value); + /** + * Add a double variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addFloat(String name, double value) { + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, value); } - public void addNDArray(String name, NumpyArray value){ - vars.put(name, Type.NDARRAY); + /** + * Add a float variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addFloat(String name, float value) { + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, (double)value); + } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, NumpyArray value) { + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, value); } - public void addNDArray(String name, INDArray value){ - vars.put(name, Type.NDARRAY); + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, new NumpyArray(value)); } - public void addList(String name, Object[] value){ - vars.put(name, Type.LIST); - listVars.put(name, value); + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + * @param value the value to add + */ + public void addList(String name, Object[] value) { + vars.put(name, PythonVariables.Type.LIST); + listVariables.put(name, value); } - public void addFile(String name, String value){ - vars.put(name, Type.FILE); - fileVars.put(name, value); + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + * @param value the value to add + */ + public void addFile(String name, String value) { + vars.put(name, PythonVariables.Type.FILE); + fileVariables.put(name, value); } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + * @param value the value to add + */ + public void addDict(String name, 
java.util.Map value) { + vars.put(name, PythonVariables.Type.DICT); + dictVariables.put(name, value); + } /** * * @param name name of the variable * @param value new value for the variable - * @throws Exception */ public void setValue(String name, Object value) { Type type = vars.get(name); - if (type == Type.BOOL){ - boolVars.put(name, (Boolean)value); + if (type == PythonVariables.Type.BOOL){ + boolVariables.put(name, (Boolean)value); } - else if (type == Type.INT){ - if (value instanceof Long){ - intVars.put(name, ((Long)value)); - } - else if (value instanceof Integer){ - intVars.put(name, ((Integer)value).longValue()); - - } + else if (type == PythonVariables.Type.INT){ + Number number = (Number) value; + intVariables.put(name, number.longValue()); } - else if (type == Type.FLOAT){ - floatVars.put(name, (Double)value); + else if (type == PythonVariables.Type.FLOAT){ + Number number = (Number) value; + floatVariables.put(name, number.doubleValue()); } - else if (type == Type.NDARRAY){ + else if (type == PythonVariables.Type.NDARRAY){ if (value instanceof NumpyArray){ ndVars.put(name, (NumpyArray)value); } - else if (value instanceof INDArray){ - ndVars.put(name, new NumpyArray((INDArray) value)); + else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) { + ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value)); } else{ throw new RuntimeException("Unsupported type: " + value.getClass().toString()); } } - else if (type == Type.LIST){ - listVars.put(name, (Object[]) value); + else if (type == PythonVariables.Type.LIST) { + if (value instanceof java.util.List) { + value = ((java.util.List) value).toArray(); + listVariables.put(name, (Object[]) value); + } + else if(value instanceof org.json.JSONArray) { + org.json.JSONArray jsonArray = (org.json.JSONArray) value; + Object[] copyArr = new Object[jsonArray.length()]; + for(int i = 0; i < copyArr.length; i++) { + copyArr[i] = jsonArray.get(i); + } + listVariables.put(name, copyArr); + + } + else { + listVariables.put(name, (Object[]) value); + } } - else if (type == Type.FILE){ - fileVars.put(name, (String)value); + else if(type == PythonVariables.Type.DICT) { + dictVariables.put(name,(java.util.Map) value); + } + else if (type == PythonVariables.Type.FILE){ + fileVariables.put(name, (String)value); } else{ - strVars.put(name, (String)value); + strVariables.put(name, (String)value); } } - public Object getValue(String name){ + /** + * Do a general object lookup. + * The look up will happen relative to the {@link Type} + * of variable is described in the + * @param name the name of the variable to get + * @return teh value for the variable with the given name + */ + public Object getValue(String name) { Type type = vars.get(name); - Map map = maps.get(type); + java.util.Map map = maps.get(type); return map.get(name); } + + /** + * Returns a boolean variable with the given name. 
+ * @param name the variable name to get the value for + * @return the retrieved boolean value + */ + public boolean getBooleanValue(String name) { + return boolVariables.get(name); + } + + /** + * + * @param name the variable name + * @return the dictionary value + */ + public java.util.Map getDictValue(String name) { + return dictVariables.get(name); + } + + /** + /** + * + * @param name the variable name + * @return the string value + */ public String getStrValue(String name){ - return strVars.get(name); + return strVariables.get(name); } - public long getIntValue(String name){ - return intVars.get(name); + /** + * + * @param name the variable name + * @return the long value + */ + public Long getIntValue(String name){ + return intVariables.get(name); } - public double getFloatValue(String name){ - return floatVars.get(name); + /** + * + * @param name the variable name + * @return the float value + */ + public Double getFloatValue(String name){ + return floatVariables.get(name); } + /** + * + * @param name the variable name + * @return the numpy array value + */ public NumpyArray getNDArrayValue(String name){ return ndVars.get(name); } + /** + * + * @param name the variable name + * @return the list value as an object array + */ public Object[] getListValue(String name){ - return listVars.get(name); + return listVariables.get(name); } + /** + * + * @param name the variable name + * @return the value of the given file name + */ public String getFileValue(String name){ - return fileVars.get(name); + return fileVariables.get(name); } + /** + * Returns the type for the given variable name + * @param name the name of the variable to get the type for + * @return the type for the given variable + */ public Type getType(String name){ return vars.get(name); } + /** + * Get all the variables present as a string array + * @return the variable names for this variable sset + */ public String[] getVariables() { String[] strArr = new String[vars.size()]; return vars.keySet().toArray(strArr); } - public Map getBoolVariables(){ - return boolVars; - } - public Map getStrVariables(){ - return strVars; - } - - public Map getIntVariables(){ - return intVars; - } - - public Map getFloatVariables(){ - return floatVars; - } - - public Map getNDArrayVariables(){ - return ndVars; - } - - public Map getListVariables(){ - return listVars; - } - - public Map getFileVariables(){ - return fileVars; - } - - public JSONArray toJSON(){ - JSONArray arr = new JSONArray(); + /** + * This variables set as its json representation (an array of json objects) + * @return the json array output + */ + public org.json.JSONArray toJSON(){ + org.json.JSONArray arr = new org.json.JSONArray(); for (String varName: getVariables()){ - JSONObject var = new JSONObject(); + org.json.JSONObject var = new org.json.JSONObject(); var.put("name", varName); String varType = getType(varName).toString(); var.put("type", varType); - arr.add(var); + arr.put(var); } return arr; } - public static PythonVariables fromJSON(JSONArray jsonArray){ + /** + * Create a schema from a map. 
+ * This is an empty PythonVariables + * that just contains names and types with no values + * @param inputTypes the input types to convert + * @return the schema from the given map + */ + public static PythonVariables schemaFromMap(java.util.Map inputTypes) { + PythonVariables ret = new PythonVariables(); + for(java.util.Map.Entry entry : inputTypes.entrySet()) { + ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue())); + } + + return ret; + } + + /** + * Get the python variable state relative to the + * input json array + * @param jsonArray the input json array + * @return the python variables based on the input json array + */ + public static PythonVariables fromJSON(org.json.JSONArray jsonArray){ PythonVariables pyvars = new PythonVariables(); - for (int i=0; i" + +def maybe_serialize_ndarray_metadata(x): + return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x + + +def serialize_ndarray_metadata(x): + return {"address": x.__array_interface__['data'][0], + "shape": x.shape, + "strides": x.strides, + "dtype": str(x.dtype), + "_is_numpy_array": True} if __is_numpy_array(x) else x + + +def is_json_ready(key, value): + return key is not 'f2' and not inspect.ismodule(value) \ + and not hasattr(value, '__call__') + diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py new file mode 100644 index 000000000..d2ed3d5e5 --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py @@ -0,0 +1,202 @@ +#patch + +"""Implementation of __array_function__ overrides from NEP-18.""" +import collections +import functools +import os + +from numpy.core._multiarray_umath import ( + add_docstring, implement_array_function, _get_implementing_args) +from numpy.compat._inspect import getargspec + + +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + + +ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat + + +_add_docstring = add_docstring + + +def add_docstring(*args): + try: + _add_docstring(*args) + except: + pass + + +add_docstring( + implement_array_function, + """ + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. + + Arguments + --------- + implementation : function + Function that implements the operation on NumPy array without + overrides when called like ``implementation(*args, **kwargs)``. + public_api : function + Function exposed by NumPy's public API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __array_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : dict + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + Result from calling ``implementation()`` or an ``__array_function__`` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + """) + + +# exposed for testing purposes; used internally by implement_array_function +add_docstring( + _get_implementing_args, + """ + Collect arguments on which to call __array_function__. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. 
+ + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """) + + +ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') + + +def verify_matching_signatures(implementation, dispatcher): + """Verify that a dispatcher function has the right signature.""" + implementation_spec = ArgSpec(*getargspec(implementation)) + dispatcher_spec = ArgSpec(*getargspec(dispatcher)) + + if (implementation_spec.args != dispatcher_spec.args or + implementation_spec.varargs != dispatcher_spec.varargs or + implementation_spec.keywords != dispatcher_spec.keywords or + (bool(implementation_spec.defaults) != + bool(dispatcher_spec.defaults)) or + (implementation_spec.defaults is not None and + len(implementation_spec.defaults) != + len(dispatcher_spec.defaults))): + raise RuntimeError('implementation and dispatcher for %s have ' + 'different function signatures' % implementation) + + if implementation_spec.defaults is not None: + if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): + raise RuntimeError('dispatcher functions can only use None for ' + 'default argument values') + + +def set_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @set_module('numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator + + +def array_function_dispatch(dispatcher, module=None, verify=True, + docs_from_dispatcher=False): + """Decorator for adding dispatch with the __array_function__ protocol. + + See NEP-18 for example usage. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__array_function__``. + module : str, optional + __module__ attribute to set on new function, e.g., ``module='numpy'``. + By default, module is copied from the decorated function. + verify : bool, optional + If True, verify the that the signature of the dispatcher and decorated + function signatures match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + docs_from_dispatcher : bool, optional + If True, copy docs from the dispatcher function onto the dispatched + function, rather than from the implementation. This is useful for + functions defined in C, which otherwise don't have docstrings. + + Returns + ------- + Function suitable for decorating the implementation of a NumPy function. 
+ """ + + if not ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + def decorator(implementation): + if module is not None: + implementation.__module__ = module + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + return implementation + return decorator + + def decorator(implementation): + if verify: + verify_matching_signatures(implementation, dispatcher) + + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + + @functools.wraps(implementation) + def public_api(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return implement_array_function( + implementation, public_api, relevant_args, args, kwargs) + + if module is not None: + public_api.__module__ = module + + # TODO: remove this when we drop Python 2 support (functools.wraps) + # adds __wrapped__ automatically in later versions) + public_api.__wrapped__ = implementation + + return public_api + + return decorator + + +def array_function_from_dispatcher( + implementation, module=None, verify=True, docs_from_dispatcher=True): + """Like array_function_dispatcher, but with function arguments flipped.""" + + def decorator(dispatcher): + return array_function_dispatch( + dispatcher, module, verify=verify, + docs_from_dispatcher=docs_from_dispatcher)(implementation) + return decorator diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py new file mode 100644 index 000000000..890852bbc --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py @@ -0,0 +1,172 @@ +#patch 1 + +""" +======================== +Random Number Generation +======================== + +==================== ========================================================= +Utility functions +============================================================================== +random_sample Uniformly distributed floats over ``[0, 1)``. +random Alias for `random_sample`. +bytes Uniformly distributed random bytes. +random_integers Uniformly distributed integers in a given range. +permutation Randomly permute a sequence / generate a random sequence. +shuffle Randomly permute a sequence in place. +seed Seed the random number generator. +choice Random sample from 1-D array. + +==================== ========================================================= + +==================== ========================================================= +Compatibility functions +============================================================================== +rand Uniformly distributed values. +randn Normally distributed values. +ranf Uniformly distributed floating point numbers. +randint Uniformly distributed integers in a given range. +==================== ========================================================= + +==================== ========================================================= +Univariate distributions +============================================================================== +beta Beta distribution over ``[0, 1]``. +binomial Binomial distribution. +chisquare :math:`\\chi^2` distribution. +exponential Exponential distribution. +f F (Fisher-Snedecor) distribution. +gamma Gamma distribution. +geometric Geometric distribution. +gumbel Gumbel distribution. +hypergeometric Hypergeometric distribution. +laplace Laplace distribution. +logistic Logistic distribution. +lognormal Log-normal distribution. +logseries Logarithmic series distribution. 
+negative_binomial Negative binomial distribution. +noncentral_chisquare Non-central chi-square distribution. +noncentral_f Non-central F distribution. +normal Normal / Gaussian distribution. +pareto Pareto distribution. +poisson Poisson distribution. +power Power distribution. +rayleigh Rayleigh distribution. +triangular Triangular distribution. +uniform Uniform distribution. +vonmises Von Mises circular distribution. +wald Wald (inverse Gaussian) distribution. +weibull Weibull distribution. +zipf Zipf's distribution over ranked data. +==================== ========================================================= + +==================== ========================================================= +Multivariate distributions +============================================================================== +dirichlet Multivariate generalization of Beta distribution. +multinomial Multivariate generalization of the binomial distribution. +multivariate_normal Multivariate generalization of the normal distribution. +==================== ========================================================= + +==================== ========================================================= +Standard distributions +============================================================================== +standard_cauchy Standard Cauchy-Lorentz distribution. +standard_exponential Standard exponential distribution. +standard_gamma Standard Gamma distribution. +standard_normal Standard normal distribution. +standard_t Standard Student's t-distribution. +==================== ========================================================= + +==================== ========================================================= +Internal functions +============================================================================== +get_state Get tuple representing internal state of generator. +set_state Set state of generator. +==================== ========================================================= + +""" +from __future__ import division, absolute_import, print_function + +import warnings + +__all__ = [ + 'beta', + 'binomial', + 'bytes', + 'chisquare', + 'choice', + 'dirichlet', + 'exponential', + 'f', + 'gamma', + 'geometric', + 'get_state', + 'gumbel', + 'hypergeometric', + 'laplace', + 'logistic', + 'lognormal', + 'logseries', + 'multinomial', + 'multivariate_normal', + 'negative_binomial', + 'noncentral_chisquare', + 'noncentral_f', + 'normal', + 'pareto', + 'permutation', + 'poisson', + 'power', + 'rand', + 'randint', + 'randn', + 'random_integers', + 'random_sample', + 'rayleigh', + 'seed', + 'set_state', + 'shuffle', + 'standard_cauchy', + 'standard_exponential', + 'standard_gamma', + 'standard_normal', + 'standard_t', + 'triangular', + 'uniform', + 'vonmises', + 'wald', + 'weibull', + 'zipf' +] + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") + try: + from .mtrand import * + # Some aliases: + ranf = random = sample = random_sample + __all__.extend(['ranf', 'random', 'sample']) + except: + warnings.warn("numpy.random is not available when using multiple interpreters!") + + + +def __RandomState_ctor(): + """Return a RandomState instance. + + This function exists solely to assist (un)pickling. + + Note that the state of the RandomState returned here is irrelevant, as this function's + entire purpose is to return a newly allocated RandomState whose state pickle can set. + Consequently the RandomState returned by this function is a freshly allocated copy + with a seed=0. 
+ + See https://github.com/numpy/numpy/issues/4763 for a detailed discussion + + """ + return RandomState(seed=0) + +from numpy._pytesttester import PytestTester +test = PytestTester(__name__) +del PytestTester diff --git a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py new file mode 100644 index 000000000..dbdceff0e --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py @@ -0,0 +1,20 @@ +import sys +import traceback +import json +import inspect + + +try: + + pass + sys.stdout.flush() + sys.stderr.flush() +except Exception as ex: + try: + exc_info = sys.exc_info() + finally: + print(ex) + traceback.print_exception(*exc_info) + sys.stdout.flush() + sys.stderr.flush() + diff --git a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py new file mode 100644 index 000000000..ac6f5b1c1 --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py @@ -0,0 +1,50 @@ +def __is_numpy_array(x): + return str(type(x))== "" + +def __maybe_serialize_ndarray_metadata(x): + return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x + + +def __serialize_ndarray_metadata(x): + return {"address": x.__array_interface__['data'][0], + "shape": x.shape, + "strides": x.strides, + "dtype": str(x.dtype), + "_is_numpy_array": True} if __is_numpy_array(x) else x + + +def __serialize_list(x): + import json + return json.dumps(__recursive_serialize_list(x)) + + +def __serialize_dict(x): + import json + return json.dumps(__recursive_serialize_dict(x)) + +def __recursive_serialize_list(x): + out = [] + for i in x: + if __is_numpy_array(i): + out.append(__serialize_ndarray_metadata(i)) + elif isinstance(i, (list, tuple)): + out.append(__recursive_serialize_list(i)) + elif isinstance(i, dict): + out.append(__recursive_serialize_dict(i)) + else: + out.append(i) + return out + +def __recursive_serialize_dict(x): + out = {} + for k in x: + v = x[k] + if __is_numpy_array(v): + out[k] = __serialize_ndarray_metadata(v) + elif isinstance(v, (list, tuple)): + out[k] = __recursive_serialize_list(v) + elif isinstance(v, dict): + out[k] = __recursive_serialize_dict(v) + else: + out[k] = v + return out \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java new file mode 100644 index 000000000..435babf7c --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import org.junit.Assert; +import org.junit.Test; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonExecutionSandbox { + + @Test + public void testInt(){ + PythonExecutioner.setInterpreter("interp1"); + PythonExecutioner.exec("a = 1"); + PythonExecutioner.setInterpreter("interp2"); + PythonExecutioner.exec("a = 2"); + PythonExecutioner.setInterpreter("interp3"); + PythonExecutioner.exec("a = 3"); + + + PythonExecutioner.setInterpreter("interp1"); + Assert.assertEquals(1, PythonExecutioner.evalInteger("a")); + + PythonExecutioner.setInterpreter("interp2"); + Assert.assertEquals(2, PythonExecutioner.evalInteger("a")); + + PythonExecutioner.setInterpreter("interp3"); + Assert.assertEquals(3, PythonExecutioner.evalInteger("a")); + } + + @Test + public void testNDArray(){ + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("import numpy as np"); + PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + //PythonExecutioner.exec("import numpy as np"); + PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("a += 2"); + + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("a += 3"); + + PythonExecutioner.setInterpreter("main"); + //PythonExecutioner.exec("import numpy as np"); + // PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5); + } + + @Test + public void testNumpyRandom(){ + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))"); + } +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index 791950043..c8e67febb 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -15,17 +15,25 @@ ******************************************************************************/ package org.datavec.python; -import org.junit.Ignore; +import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; + import static org.junit.Assert.assertEquals; -@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") + +@javax.annotation.concurrent.NotThreadSafe public class TestPythonExecutioner { - @Test(timeout = 60000L) + + @org.junit.Test + public void testPythonSysVersion() { + PythonExecutioner.exec("import sys; print(sys.version)"); + } + + @Test public void testStr() throws Exception{ PythonVariables pyInputs = new PythonVariables(); @@ -47,7 +55,7 @@ public class TestPythonExecutioner { assertEquals("Hello World", z); } - @Test(timeout = 60000L) + @Test public void testInt()throws Exception{ PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -55,7 +63,7 @@ public class TestPythonExecutioner { pyInputs.addInt("x", 10); pyInputs.addInt("y", 20); - String code = "z = x + y"; + String code = "z = x + y"; pyOutputs.addInt("z"); @@ -64,11 +72,11 @@ public class 
TestPythonExecutioner { long z = pyOutputs.getIntValue("z"); - assertEquals(30, z); + Assert.assertEquals(30, z); } - @Test(timeout = 60000L) + @Test public void testList() throws Exception{ PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -88,18 +96,35 @@ public class TestPythonExecutioner { Object[] z = pyOutputs.getListValue("z"); - assertEquals(z.length, x.length + y.length); + Assert.assertEquals(z.length, x.length + y.length); + + for (int i = 0; i < x.length; i++) { + if(x[i] instanceof Number) { + Number xNum = (Number) x[i]; + Number zNum = (Number) z[i]; + Assert.assertEquals(xNum.intValue(), zNum.intValue()); + } + else { + Assert.assertEquals(x[i], z[i]); + } - for (int i=0; i < x.length; i++){ - assertEquals(x[i], z[i]); } - for (int i=0; i> inputData = new ArrayList<>(); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); @@ -288,10 +287,9 @@ public class ExecutionTest extends BaseSparkTest { String pythonCode = "col3 = col1 + col2"; TransformProcess tp = new TransformProcess.Builder(schema).transform( - new PythonTransform( - pythonCode, - finalSchema - ) + PythonTransform.builder().code( + "first = np.sin(first)\nsecond = np.cos(second)") + .outputSchema(schema).build() ).build(); INDArray zeros = Nd4j.zeros(shape); diff --git a/pom.xml b/pom.xml index 35ef4bcab..ada833f12 100644 --- a/pom.xml +++ b/pom.xml @@ -294,6 +294,8 @@ 3.7.5 ${python.version}-${javacpp-presets.version} + 1.17.3 + ${numpy.version}-${javacpp-presets.version} 0.3.7 2019.5
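
Usage note (not part of the patch): the refactor above replaces the old PythonTransform constructors with a Lombok @Builder and routes all value passing through typed PythonVariables, while PythonExecutioner gains named interpreters, numpy patching, and wrapped-code execution. Below is a minimal sketch of the resulting low-level flow, assuming the API exactly as it appears in this diff; the class name PythonExecutionerSketch and the interpreter name "example" are illustrative only and do not exist in the patch.

import org.datavec.python.PythonExecutioner;
import org.datavec.python.PythonVariables;

public class PythonExecutionerSketch {
    public static void main(String[] args) throws Exception {
        // Typed inputs: names and values the Python snippet can read.
        PythonVariables in = new PythonVariables();
        in.addInt("x", 10);
        in.addStr("prefix", "value = ");

        // Typed outputs: declare name and type only; exec() fills in the value.
        PythonVariables out = new PythonVariables();
        out.addStr("msg");

        // Each named interpreter keeps its own global namespace
        // (see TestPythonExecutionSandbox in this patch).
        PythonExecutioner.setInterpreter("example");
        PythonExecutioner.exec("msg = prefix + str(x * 2)", in, out);

        System.out.println(out.getStrValue("msg")); // prints "value = 20"
    }
}

The same input/output PythonVariables pair is what PythonTransform.builder() derives from its inputSchema/outputSchema via PythonUtils.schemaToPythonVariables when a transform is executed row by row.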