diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index afd128669..1f7e938b6 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -45,7 +45,7 @@ public class TypeConversion { } public int convertInt(String o) { - return Integer.parseInt(o); + return (int) Double.parseDouble(o); } public double convertDouble(Writable writable) { diff --git a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 48f9474d5..f66475357 100644 --- a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -725,7 +725,6 @@ public class ArrowConverter { case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break; default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i)); - } } @@ -802,8 +801,13 @@ public class ArrowConverter { //for proper offsets ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get()); nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity()); + case Boolean: + BitVector bitVector = (BitVector) fieldVector; + if(value instanceof Boolean) + bitVector.set(row, (boolean) value ? 1 : 0); + else + bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0); break; - } }catch(Exception e) { log.warn("Unable to set value at row " + row); diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 37df8ae52..7a6ab3f03 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -315,7 +315,7 @@ public class TestPythonTransformProcess { } @Test - public void testNumpyTransform() throws Exception { + public void testNumpyTransform() { PythonTransform pythonTransform = PythonTransform.builder() .code("a += 2; b = 'hello world'") .returnAllInputs(true) @@ -334,7 +334,42 @@ public class TestPythonTransformProcess { assertFalse(execute.isEmpty()); assertNotNull(execute.get(0)); assertNotNull(execute.get(0).get(0)); - assertEquals("hello world",execute.get(0).get(0).toString()); + assertNotNull(execute.get(0).get(1)); + assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get()); + assertEquals("hello world",execute.get(0).get(1).toString()); + } + + @Test + public void testWithSetupRun() throws Exception { + + PythonTransform pythonTransform = PythonTransform.builder() + .code("five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n") + .returnAllInputs(true) + .setupAndRun(true) + .build(); + + List> inputs = new ArrayList<>(); + inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)), + new NDArrayWritable(Nd4j.scalar(2).reshape(1,1)))); + Schema inputSchema = new Builder() + .addColumnNDArray("a",new long[]{1,1}) + .addColumnNDArray("b", new long[]{1, 1}) + .build(); + + TransformProcess tp = new TransformProcess.Builder(inputSchema) + .transform(pythonTransform) + .build(); + List> execute = LocalTransformExecutor.execute(inputs, tp); + assertFalse(execute.isEmpty()); + assertNotNull(execute.get(0)); + assertNotNull(execute.get(0).get(0)); + assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get()); } } \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index 24a2c2e09..dd1613d0a 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -29,6 +29,7 @@ import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.linalg.api.buffer.DataType; +import static org.nd4j.linalg.api.buffer.DataType.FLOAT; /** @@ -46,55 +47,45 @@ public class NumpyArray { private long[] strides; private DataType dtype; private INDArray nd4jArray; + static { //initialize Nd4j.scalar(1.0); - nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); } @Builder - public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) { + public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) { this.address = address; this.shape = shape; this.strides = strides; this.dtype = dtype; setND4JArray(); - if (copy){ + if (copy) { nd4jArray = nd4jArray.dup(); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); - } } - public NumpyArray copy(){ + + + public NumpyArray copy() { return new NumpyArray(nd4jArray.dup()); } - public NumpyArray(long address, long[] shape, long strides[]){ - this(address, shape, strides, false,DataType.FLOAT); + public NumpyArray(long address, long[] shape, long strides[]) { + this(address, shape, strides, FLOAT, false); } - public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ + public NumpyArray(long address, long[] shape, long strides[], DataType dtype) { this(address, shape, strides, dtype, false); } - public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy){ - this.address = address; - this.shape = shape; - this.strides = strides; - this.dtype = dtype; - setND4JArray(); - if (copy){ - nd4jArray = nd4jArray.dup(); - Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); - this.address = nd4jArray.data().address(); - } - } private void setND4JArray() { long size = 1; - for(long d: shape) { + for (long d : shape) { size *= d; } Pointer ptr = nativeOps.pointerForAddress(address); @@ -107,11 +98,11 @@ public class NumpyArray { nd4jStrides[i] = strides[i] / elemSize; } - nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } - public NumpyArray(INDArray nd4jArray){ + public NumpyArray(INDArray nd4jArray) { Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); address = buff.pointer().address(); @@ -119,7 +110,7 @@ public class NumpyArray { long[] nd4jStrides = nd4jArray.stride(); strides = new long[nd4jStrides.length]; int elemSize = buff.getElementSize(); - for(int i=0; i= 1,"Python code must not be empty!"); + checkNotNull("Python code must not be null!", pythonCode); + checkState(!pythonCode.isEmpty(), "Python code must not be empty!"); code = pythonCode; } - - @Override public void setInputSchema(Schema inputSchema) { this.inputSchema = inputSchema; - try{ + try { pyInputs = schemaToPythonVariables(inputSchema); PythonVariables pyOuts = new PythonVariables(); pyOuts.addInt("out"); @@ -62,17 +60,15 @@ public class PythonCondition implements Condition { .outputs(pyOuts) .build(); - } - catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } - } @Override - public Schema getInputSchema(){ + public Schema getInputSchema() { return inputSchema; } @@ -84,40 +80,39 @@ public class PythonCondition implements Condition { } @Override - public String outputColumnName(){ + public String outputColumnName() { return outputColumnNames()[0]; } @Override - public String[] columnNames(){ + public String[] columnNames() { return outputColumnNames(); } @Override - public String columnName(){ + public String columnName() { return outputColumnName(); } @Override - public Schema transform(Schema inputSchema){ + public Schema transform(Schema inputSchema) { return inputSchema; } @Override public boolean condition(List list) { PythonVariables inputs = getPyInputsFromWritables(list); - try{ - PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs()); + try { + pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs()); boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0; return ret; - } - catch (Exception e) { + } catch (Exception e) { throw new RuntimeException(e); } } @Override - public boolean condition(Object input){ + public boolean condition(Object input) { return condition(input); } @@ -135,28 +130,27 @@ public class PythonCondition implements Condition { private PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (int i = 0; i < inputSchema.numColumns(); i++){ + for (int i = 0; i < inputSchema.numColumns(); i++) { String name = inputSchema.getName(i); Writable w = writables.get(i); - PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i)); - switch (pyType){ + PythonType pyType = pyInputs.getType(inputSchema.getName(i)); + switch (pyType.getName()) { case INT: if (w instanceof LongWritable) { - ret.addInt(name, ((LongWritable)w).get()); - } - else { - ret.addInt(name, ((IntWritable)w).get()); + ret.addInt(name, ((LongWritable) w).get()); + } else { + ret.addInt(name, ((IntWritable) w).get()); } break; case FLOAT: - ret.addFloat(name, ((DoubleWritable)w).get()); + ret.addFloat(name, ((DoubleWritable) w).get()); break; case STR: ret.addStr(name, w.toString()); break; case NDARRAY: - ret.addNDArray(name,((NDArrayWritable)w).get()); + ret.addNDArray(name, ((NDArrayWritable) w).get()); break; } } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java new file mode 100644 index 000000000..c3563bfc2 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java @@ -0,0 +1,188 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Emulates multiples interpreters in a single interpreter. + * This works by simply obfuscating/de-obfuscating variable names + * such that only the required subset of the global namespace is "visible" + * at any given time. + * By default, there exists a "main" context emulating the default interpreter + * and cannot be deleted. + * @author Fariz Rahman + */ + + +public class PythonContextManager { + + private static Set contexts = new HashSet<>(); + private static AtomicBoolean init = new AtomicBoolean(false); + private static String currentContext; + private static final String MAIN_CONTEXT = "main"; + static { + init(); + } + + private static void init() { + if (init.get()) return; + new PythonExecutioner(); + init.set(true); + currentContext = MAIN_CONTEXT; + contexts.add(currentContext); + } + + + public static void addContext(String contextName) throws PythonException { + if (!validateContextName(contextName)) { + throw new PythonException("Invalid context name: " + contextName); + } + contexts.add(contextName); + } + + public static boolean hasContext(String contextName) { + return contexts.contains(contextName); + } + + + public static boolean validateContextName(String s) { + if (s.length() == 0) return false; + if (!Character.isJavaIdentifierStart(s.charAt(0))) return false; + for (int i = 1; i < s.length(); i++) + if (!Character.isJavaIdentifierPart(s.charAt(i))) + return false; + return true; + } + + private static String getContextPrefix(String contextName) { + return "__collapsed__" + contextName + "__"; + } + + private static String getCollapsedVarNameForContext(String varName, String contextName) { + return getContextPrefix(contextName) + varName; + } + + private static String expandCollapsedVarName(String varName, String contextName) { + String prefix = "__collapsed__" + contextName + "__"; + return varName.substring(prefix.length()); + + } + + private static void collapseContext(String contextName) { + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!((keyStr.startsWith("__") && keyStr.endsWith("__")) || keyStr.startsWith("__collapsed_"))) { + String collapsedKey = getCollapsedVarNameForContext(keyStr, contextName); + PythonObject val = globals.attr("pop").call(key); + globals.set(new PythonObject(collapsedKey), val); + } + } + } + + private static void expandContext(String contextName) { + String prefix = getContextPrefix(contextName); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (keyStr.startsWith(prefix)) { + String expandedKey = expandCollapsedVarName(keyStr, contextName); + PythonObject val = globals.attr("pop").call(key); + globals.set(new PythonObject(expandedKey), val); + } + } + + } + + public static void setContext(String contextName) throws PythonException{ + if (contextName.equals(currentContext)) { + return; + } + if (!hasContext(contextName)) { + addContext(contextName); + } + collapseContext(currentContext); + expandContext(contextName); + currentContext = contextName; + + } + + public static void setMainContext() { + try{ + setContext(MAIN_CONTEXT); + } + catch (PythonException pe){ + throw new RuntimeException(pe); + } + + } + + public static String getCurrentContext() { + return currentContext; + } + + public static void deleteContext(String contextName) throws PythonException { + if (contextName.equals(MAIN_CONTEXT)) { + throw new PythonException("Can not delete main context!"); + } + if (contextName.equals(currentContext)) { + throw new PythonException("Can not delete current context!"); + } + String prefix = getContextPrefix(contextName); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (keyStr.startsWith(prefix)) { + globals.attr("__delitem__").call(key); + } + } + contexts.remove(contextName); + } + + public static void deleteNonMainContexts() { + try{ + setContext(MAIN_CONTEXT); // will never fail + for (String c : contexts.toArray(new String[0])) { + if (!c.equals(MAIN_CONTEXT)) { + deleteContext(c); // will never fail + } + } + }catch(Exception e){ + throw new RuntimeException(e); + } + } + + public String[] getContexts() { + return contexts.toArray(new String[0]); + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java new file mode 100644 index 000000000..d66c67a32 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java @@ -0,0 +1,44 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +/** + * Thrown when an exception occurs in python land + */ +public class PythonException extends Exception { + public PythonException(String message){ + super(message); + } + private static String getExceptionString(PythonObject exception){ + if (Python.isinstance(exception, Python.ExceptionType())){ + String exceptionClass = Python.type(exception).attr("__name__").toString(); + String message = exception.toString(); + return exceptionClass + ": " + message; + } + return exception.toString(); + } + public PythonException(PythonObject exception){ + this(getExceptionString(exception)); + } + public PythonException(String message, Throwable cause){ + super(message, cause); + } + public PythonException(Throwable cause){ + super(cause); + } +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 8afec3b5a..a06e60e98 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -1,5 +1,6 @@ + /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,33 +15,29 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + package org.datavec.python; - -import java.io.*; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - - import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; -import org.json.JSONObject; -import org.json.JSONArray; -import org.bytedeco.javacpp.*; -import org.bytedeco.cpython.*; -import static org.bytedeco.cpython.global.python.*; -import org.bytedeco.numpy.global.numpy; - -import static org.datavec.python.PythonUtils.*; - -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.factory.Nd4j; +import org.bytedeco.cpython.PyThreadState; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyThreadState_Get; +import static org.datavec.python.Python.*; /** * Allows execution of python scripts managed by @@ -58,18 +55,14 @@ import org.nd4j.linalg.io.ClassPathResource; * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external * python distribution such as anaconda. * - * 3. A main interpreter: This is the default interpreter to be used with the main thread. - * We may initialize one or more relative to the thread invoking the python code. - * - * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of + * 3. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of * native libraries needed by numpy are allowed to load in the proper order. If we don't do this, - * it causes a variety of issues with running numpy. + * it causes a variety of issues with running numpy. (User must still include "import numpy as np" in their scripts). * - * 5. Various python scripts pre defined on the classpath included right with the java code. + * 4. Various python scripts pre defined on the classpath included right with the java code. * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior * in order for us to manipulate values within the python memory, as well as pulling them out of memory - * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)} - * as an example. + * for integration within the internal python executioner. * * For more information on how this works, please take a look at the {@link #init()} * method. @@ -94,80 +87,242 @@ import org.nd4j.linalg.io.ClassPathResource; * * * @author Fariz Rahman - * @author Adam Gibson + * @author Adam Gibson */ -/** - * Allows execution of python scripts managed by - * an internal interpreter. - * An end user may specify a python script to run - * via any of the execution methods available in this class. - * - * At static initialization time (when the class is first initialized) - * a number of components are setup: - * 1. The python path. A user may over ride this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY} - * - * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers, - * a user may also over ride the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType} - * values. This will allow the user to determine whether the javacpp default python path is used at all, and if so - * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external - * python distribution such as anaconda. - * - * 3. A main interpreter: This is the default interpreter to be used with the main thread. - * We may initialize one or more relative to the thread invoking the python code. - * - * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of - * native libraries needed by numpy are allowed to load in the proper order. If we don't do this, - * it causes a variety of issues with running numpy. - * - * 5. Various python scripts pre defined on the classpath included right with the java code. - * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior - * in order for us to manipulate values within the python memory, as well as pulling them out of memory - * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)} - * as an example. - * - * For more information on how this works, please take a look at the {@link #init()} - * method. - * - * Generally, a user defining a python script for use by the python executioner - * will have a set of defined target input values and output values. - * These values should not be present when actually running the script, but just referenced. - * In order to test your python script for execution outside the engine, - * we recommend commenting out a few default values as dummy input values. - * This will allow an end user to test their script before trying to use the server. - * - * In order to get output values out of a python script, all a user has to do - * is define the output variables they want being used in the final output in the actual pipeline. - * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name - * and based on the configured {@link PythonVariables} passed as outputs - * to one of the execution methods, we can pull the values out automatically. - * - * For input definitions, it is similar. You just define the values you want used in - * {@link PythonVariables} and we will automatically generate code for defining those values - * as desired for running. This allows the user to customize values dynamically - * at runtime but reference them by name in a python script. - * - * - * @author Fariz Rahman - * @author Adam Gibson - */ @Slf4j public class PythonExecutioner { - private final static String fileVarName = "_f" + Nd4j.getRandom().nextInt(); - private static boolean init; + + private static AtomicBoolean init = new AtomicBoolean(false); public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path"; public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append"; public final static String DEFAULT_APPEND_TYPE = "before"; - private static Map interpreters = new java.util.concurrent.ConcurrentHashMap<>(); - private static PyThreadState currentThreadState; - private static PyThreadState mainThreadState; - public final static String ALL_VARIABLES_KEY = "allVariables"; - public final static String MAIN_INTERPRETER_NAME = "main"; - private static String clearVarsCode; + private final static String PYTHON_EXCEPTION_KEY = "__python_exception__"; - private static String currentInterpreter = MAIN_INTERPRETER_NAME; + static { + init(); + } + + private static synchronized void init() { + if (init.get()) { + return; + } + initPythonPath(); + init.set(true); + log.info("CPython: PyEval_InitThreads()"); + PyEval_InitThreads(); + log.info("CPython: Py_InitializeEx()"); + Py_InitializeEx(0); + numpy._import_array(); + } + + private static synchronized void simpleExec(String code) throws PythonException{ + log.debug(code); + log.info("CPython: PyRun_SimpleStringFlag()"); + + int result = PyRun_SimpleStringFlags(code, null); + if (result != 0) { + throw new PythonException("Execution failed, unable to retrieve python exception."); + } + } + + public static boolean validateVariableName(String s) { + if (s.isEmpty()) return false; + if (!Character.isJavaIdentifierStart(s.charAt(0))) return false; + for (int i = 1; i < s.length(); i++) + if (!Character.isJavaIdentifierPart(s.charAt(i))) + return false; + return true; + } + + + /** + * Sets a variable in the global scope of the current context (See @PythonContextManager). + * This is equivalent to `exec("a = b");` where a is the variable name + * and b is the variable value. + * @param varName Name of the python variable being set. Should be a valid python identifier string + * @param pythonObject Value for the python variable + * @throws Exception + */ + public static void setVariable(String varName, PythonObject pythonObject) throws PythonException{ + if (!validateVariableName(varName)){ + throw new PythonException("Invalid variable name: " + varName); + } + Python.globals().set(new PythonObject(varName), pythonObject); + } + + public static void setVariable(String varName, PythonType varType, Object value) throws PythonException { + PythonObject pythonObject; + switch (varType.getName()) { + case STR: + pythonObject = new PythonObject(PythonType.STR.convert(value)); + break; + case INT: + pythonObject = new PythonObject(PythonType.INT.convert(value)); + break; + case FLOAT: + pythonObject = new PythonObject(PythonType.FLOAT.convert(value)); + break; + case BOOL: + pythonObject = new PythonObject(PythonType.BOOL.convert(value)); + break; + case NDARRAY: + pythonObject = new PythonObject(PythonType.NDARRAY.convert(value)); + break; + case LIST: + pythonObject = new PythonObject(PythonType.LIST.convert(value)); + break; + case DICT: + pythonObject = new PythonObject(PythonType.DICT.convert(value)); + break; + case BYTES: + pythonObject = new PythonObject(PythonType.BYTES.convert(value)); + break; + default: + throw new PythonException("Unsupported type: " + varType); + + } + setVariable(varName, pythonObject); + } + + public static void setVariables(PythonVariables pyVars) throws PythonException{ + if (pyVars == null) return; + for (String varName : pyVars.getVariables()) { + setVariable(varName, pyVars.getType(varName), pyVars.getValue(varName)); + } + } + + public static PythonObject getVariable(String varName) { + return Python.globals().attr("get").call(varName); + } + + public static T getVariable(String varName, PythonType varType) throws PythonException{ + PythonObject pythonObject = getVariable(varName); + return varType.toJava(pythonObject); + } + + public static void getVariables(PythonVariables pyVars) throws PythonException { + for (String varName : pyVars.getVariables()) { + pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName))); + } + } + + + private static String getWrappedCode(String code) { + try (InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { + String base = IOUtils.toString(is, Charset.defaultCharset()); + StringBuffer indentedCode = new StringBuffer(); + for (String split : code.split("\n")) { + indentedCode.append(" " + split + "\n"); + + } + + String out = base.replace(" pass", indentedCode); + return out; + } catch (IOException e) { + throw new IllegalStateException("Unable to read python code!", e); + } + + } + + private static void throwIfExecutionFailed() throws PythonException{ + PythonObject ex = getVariable(PYTHON_EXCEPTION_KEY); + if (ex != null && !ex.toString().isEmpty()){ + setVariable(PYTHON_EXCEPTION_KEY, new PythonObject("")); + throw new PythonException(ex); + } + } + + public static void exec(String code) throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + } + + public static void exec(String code, PythonVariables outputVariables)throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + getVariables(outputVariables); + } + + public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException { + setVariables(inputVariables); + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + getVariables(outputVariables); + } + + public static PythonVariables execAndReturnAllVariables(String code) throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + PythonVariables out = new PythonVariables(); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys")); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!keyStr.startsWith("_")) { + PythonObject val = globals.get(key); + if (Python.isinstance(val, intType())) { + out.addInt(keyStr, val.toInt()); + } else if (Python.isinstance(val, floatType())) { + out.addFloat(keyStr, val.toDouble()); + } else if (Python.isinstance(val, strType())) { + out.addStr(keyStr, val.toString()); + } else if (Python.isinstance(val, boolType())) { + out.addBool(keyStr, val.toBoolean()); + } else if (Python.isinstance(val, listType())) { + out.addList(keyStr, val.toList().toArray(new Object[0])); + } else if (Python.isinstance(val, dictType())) { + out.addDict(keyStr, val.toMap()); + } + } + } + return out; + + } + + public static PythonVariables getAllVariables() throws PythonException{ + PythonVariables out = new PythonVariables(); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!keyStr.startsWith("_")) { + PythonObject val = globals.get(key); + if (Python.isinstance(val, intType())) { + out.addInt(keyStr, val.toInt()); + } else if (Python.isinstance(val, floatType())) { + out.addFloat(keyStr, val.toDouble()); + } else if (Python.isinstance(val, strType())) { + out.addStr(keyStr, val.toString()); + } else if (Python.isinstance(val, boolType())) { + out.addBool(keyStr, val.toBoolean()); + } else if (Python.isinstance(val, listType())) { + out.addList(keyStr, val.toList().toArray(new Object[0])); + } else if (Python.isinstance(val, dictType())) { + out.addDict(keyStr, val.toMap()); + } else { + PythonObject np = importModule("numpy"); + if (Python.isinstance(val, np.attr("ndarray"), np.attr("generic"))) { + out.addNDArray(keyStr, val.toNumpy()); + } + } + + } + } + return out; + } + + public static PythonVariables execAndReturnAllVariables(String code, PythonVariables inputs) throws Exception{ + setVariables(inputs); + simpleExec(getWrappedCode(code)); + return getAllVariables(); + } /** * One of a few desired values @@ -178,7 +333,7 @@ public class PythonExecutioner { * NONE: Don't use javacpp's python path at all */ public enum JavaCppPathType { - BEFORE,AFTER,NONE + BEFORE, AFTER, NONE } /** @@ -186,36 +341,36 @@ public class PythonExecutioner { * Generally you can just use the PYTHONPATH environment variable, * but if you need to set it from code, this can work as well. */ - public static synchronized void setPythonPath() { - if(!init) { + + public static synchronized void initPythonPath() { + if (!init.get()) { try { String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); - if(path == null) { + if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); Py_SetPath(packages); - } - else { + } else { log.info("Setting python path " + path); StringBuffer sb = new StringBuffer(); File[] packages = numpy.cachePackages(); - JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE,DEFAULT_APPEND_TYPE).toUpperCase()); - switch(pathAppendValue) { + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); + switch (pathAppendValue) { case BEFORE: - for(File cacheDir : packages) { + for (File cacheDir : packages) { sb.append(cacheDir); sb.append(java.io.File.pathSeparator); } sb.append(path); - log.info("Prepending javacpp python path " + sb.toString()); + log.info("Prepending javacpp python path: {}", sb.toString()); break; case AFTER: sb.append(path); - for(File cacheDir : packages) { + for (File cacheDir : packages) { sb.append(cacheDir); sb.append(java.io.File.pathSeparator); } @@ -229,916 +384,15 @@ public class PythonExecutioner { } //prepend the javacpp packages - log.info("Final python path " + sb.toString()); + log.info("Final python path: {}", sb.toString()); Py_SetPath(sb.toString()); } } catch (IOException e) { log.error("Failed to set python path.", e); } - } - else { + } else { throw new IllegalStateException("Unable to reset python path. Already initialized."); } } - - /** - * Initialize the name space and the python execution - * Calling this method more than once will be a no op - */ - public static synchronized void init() { - if(init) { - return; - } - - try(InputStream is = new org.nd4j.linalg.io.ClassPathResource("pythonexec/clear_vars.py").getInputStream()) { - clearVarsCode = IOUtils.toString(new java.io.InputStreamReader(is)); - } catch (java.io.IOException e) { - throw new IllegalStateException("Unable to read pythonexec/clear_vars.py"); - } - - log.info("CPython: PyEval_InitThreads()"); - PyEval_InitThreads(); - log.info("CPython: Py_InitializeEx()"); - Py_InitializeEx(0); - log.info("CPython: PyGILState_Release()"); - init = true; - interpreters.put(MAIN_INTERPRETER_NAME, PyThreadState_Get()); - numpy._import_array(); - applyPatches(); - } - - - /** - * Run {@link #resetInterpreter(String)} - * on all interpreters. - */ - public static void resetAllInterpreters() { - for(String interpreter : interpreters.keySet()) { - resetInterpreter(interpreter); - } - } - - /** - * Reset the main interpreter. - * For more information see {@link #resetInterpreter(String)} - */ - public static void resetMainInterpreter() { - resetInterpreter(MAIN_INTERPRETER_NAME); - } - - /** - * Reset the interpreter with the given name. - * Runs pythonexec/clear_vars.py - * For more information see: - * https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script - * @param interpreterName the interpreter name to - * reset - */ - public static synchronized void resetInterpreter(String interpreterName) { - Preconditions.checkState(hasInterpreter(interpreterName)); - log.info("Resetting interpreter " + interpreterName); - String oldInterpreter = currentInterpreter; - setInterpreter(interpreterName); - exec("pass"); - //exec(interpreterName); // ?? - setInterpreter(oldInterpreter); - } - - /** - * Clear the non main intrepreters. - */ - public static void clearNonMainInterpreters() { - for(String key : interpreters.keySet()) { - if(!key.equals(MAIN_INTERPRETER_NAME)) { - deleteInterpreter(key); - } - } - } - - public static PythonVariables defaultPythonVariableOutput() { - PythonVariables ret = new PythonVariables(); - ret.add(ALL_VARIABLES_KEY, PythonVariables.Type.DICT); - return ret; - } - - /** - * Return the python path being used. - * @return a string specifying the python path in use - */ - public static String getPythonPath() { - return new BytePointer(Py_GetPath()).getString(); - } - - - static { - setPythonPath(); - init(); - } - - - /* ---------sub-interpreter and gil management-----------*/ - public static void setInterpreter(String interpreterName) { - if (!hasInterpreter(interpreterName)){ - PyThreadState main = PyThreadState_Get(); - PyThreadState ts = Py_NewInterpreter(); - - interpreters.put(interpreterName, ts); - PyThreadState_Swap(main); - } - - currentInterpreter = interpreterName; - } - - /** - * Returns the current interpreter. - * @return - */ - public static String getInterpreter() { - return currentInterpreter; - } - - - public static boolean hasInterpreter(String interpreterName){ - return interpreters.containsKey(interpreterName); - } - - public static void deleteInterpreter(String interpreterName) { - if (interpreterName.equals("main")){ - throw new IllegalArgumentException("Can not delete main interpreter"); - } - - Py_EndInterpreter(interpreters.remove(interpreterName)); - } - - private static synchronized void acquireGIL() { - log.info("acquireGIL()"); - log.info("CPython: PyEval_SaveThread()"); - mainThreadState = PyEval_SaveThread(); - log.info("CPython: PyThreadState_New()"); - currentThreadState = PyThreadState_New(interpreters.get(currentInterpreter).interp()); - log.info("CPython: PyEval_RestoreThread()"); - PyEval_RestoreThread(currentThreadState); - log.info("CPython: PyThreadState_Swap()"); - PyThreadState_Swap(currentThreadState); - - } - - private static synchronized void releaseGIL() { - log.info("CPython: PyEval_SaveThread()"); - PyEval_SaveThread(); - log.info("CPython: PyEval_RestoreThread()"); - PyEval_RestoreThread(mainThreadState); - } - - /* -------------------*/ - /** - * Print the python version to standard out. - */ - public static void printPythonVersion() { - exec("import sys; print(sys.version) sys.stdout.flush();"); - } - - - - private static String inputCode(PythonVariables pyInputs)throws Exception { - String inputCode = ""; - if (pyInputs == null){ - return inputCode; - } - - Map strInputs = pyInputs.getStrVariables(); - Map intInputs = pyInputs.getIntVariables(); - Map floatInputs = pyInputs.getFloatVariables(); - Map ndInputs = pyInputs.getNdVars(); - Map listInputs = pyInputs.getListVariables(); - Map fileInputs = pyInputs.getFileVariables(); - Map> dictInputs = pyInputs.getDictVariables(); - - String[] varNames; - - - varNames = strInputs.keySet().toArray(new String[strInputs.size()]); - for(String varName: varNames) { - Preconditions.checkNotNull(varName,"Var name is null!"); - Preconditions.checkNotNull(varName.isEmpty(),"Var name can not be empty!"); - String varValue = strInputs.get(varName); - //inputCode += varName + "= {}\n"; - if(varValue != null) - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - else { - inputCode += varName + " = ''\n"; - } - } - - varNames = intInputs.keySet().toArray(new String[intInputs.size()]); - for(String varName: varNames) { - Long varValue = intInputs.get(varName); - if(varValue != null) - inputCode += varName + " = " + varValue.toString() + "\n"; - else { - inputCode += " = 0\n"; - } - } - - varNames = dictInputs.keySet().toArray(new String[dictInputs.size()]); - for(String varName: varNames) { - Map varValue = dictInputs.get(varName); - if(varValue != null) { - throw new IllegalArgumentException("Unable to generate input code for dictionaries."); - } - else { - inputCode += " = {}\n"; - } - } - - varNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); - for(String varName: varNames){ - Double varValue = floatInputs.get(varName); - if(varValue != null) - inputCode += varName + " = " + varValue.toString() + "\n"; - else { - inputCode += varName + " = 0.0\n"; - } - } - - varNames = listInputs.keySet().toArray(new String[listInputs.size()]); - for (String varName: varNames) { - Object[] varValue = listInputs.get(varName); - if(varValue != null) { - String listStr = jArrayToPyString(varValue); - inputCode += varName + " = " + listStr + "\n"; - } - else { - inputCode += varName + " = []\n"; - } - - } - - varNames = fileInputs.keySet().toArray(new String[fileInputs.size()]); - for(String varName: varNames) { - String varValue = fileInputs.get(varName); - if(varValue != null) - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - else { - inputCode += varName + " = ''\n"; - } - } - - if (!ndInputs.isEmpty()) { - inputCode += "import ctypes\n\nimport sys\nimport numpy as np\n"; - varNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); - - String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape)\n"; - inputCode += converter; - for(String varName: varNames) { - NumpyArray npArr = ndInputs.get(varName); - if(npArr == null) - continue; - - npArr = npArr.copy(); - String shapeStr = "("; - for (long d: npArr.getShape()){ - shapeStr += d + ","; - } - shapeStr += ")"; - String code; - String ctype; - if (npArr.getDtype() == DataType.FLOAT) { - - ctype = "ctypes.c_float"; - } - else if (npArr.getDtype() == DataType.DOUBLE) { - ctype = "ctypes.c_double"; - } - else if (npArr.getDtype() == DataType.SHORT) { - ctype = "ctypes.c_int16"; - } - else if (npArr.getDtype() == DataType.INT) { - ctype = "ctypes.c_int32"; - } - else if (npArr.getDtype() == DataType.LONG){ - ctype = "ctypes.c_int64"; - } - else{ - throw new Exception("Unsupported data type: " + npArr.getDtype().toString() + "."); - } - - code = "__arr_converter(" + npArr.getAddress() + "," + shapeStr + "," + ctype + ")"; - code = varName + "=" + code + "\n"; - inputCode += code; - } - - } - return inputCode; - } - - - private static synchronized void _readOutputs(PythonVariables pyOutputs) throws IOException { - File f = new File(getTempFile()); - Preconditions.checkState(f.exists(),"File " + f.getAbsolutePath() + " failed to get written for reading outputs!"); - String json = FileUtils.readFileToString(f, Charset.defaultCharset()); - log.info("Executioner output: "); - log.info(json); - f.delete(); - - if(json.isEmpty()) { - log.warn("No json found fore reading outputs. Returning."); - return; - } - - try { - JSONObject jobj = new JSONObject(json); - for (String varName: pyOutputs.getVariables()) { - PythonVariables.Type type = pyOutputs.getType(varName); - if (type == PythonVariables.Type.NDARRAY) { - JSONObject varValue = (JSONObject)jobj.get(varName); - long address = (Long) varValue.getLong("address"); - JSONArray shapeJson = (JSONArray) varValue.get("shape"); - JSONArray stridesJson = (JSONArray) varValue.get("strides"); - long[] shape = jsonArrayToLongArray(shapeJson); - long[] strides = jsonArrayToLongArray(stridesJson); - String dtypeName = (String)varValue.get("dtype"); - DataType dtype; - if (dtypeName.equals("float64")) { - dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")) { - dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")) { - dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")) { - dtype = DataType.INT; - } - else if (dtypeName.equals("int64")) { - dtype = DataType.LONG; - } - else{ - throw new Exception("Unsupported array type " + dtypeName + "."); - } - - pyOutputs.setValue(varName, new NumpyArray(address, shape, strides, dtype, true)); - - } - else if (type == PythonVariables.Type.LIST) { - JSONArray varValue = (JSONArray) jobj.get(varName); - pyOutputs.setValue(varName, varValue); - } - else if (type == PythonVariables.Type.DICT) { - Map map = toMap((JSONObject) jobj.get(varName)); - pyOutputs.setValue(varName, map); - - } - else{ - pyOutputs.setValue(varName, jobj.get(varName)); - } - } - } - catch (Exception e){ - throw new RuntimeException(e); - } - } - - - - - private static synchronized void _exec(String code) { - log.debug(code); - log.info("CPython: PyRun_SimpleStringFlag()"); - - int result = PyRun_SimpleStringFlags(code, null); - if (result != 0) { - log.info("CPython: PyErr_Print"); - PyErr_Print(); - throw new RuntimeException("exec failed"); - } - } - - private static synchronized void _exec_wrapped(String code) { - _exec(getWrappedCode(code)); - } - - /** - * Executes python code. Also manages python thread state. - * @param code the code to run - */ - - public static void exec(String code) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) {// FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - log.info("Exec done"); - releaseGIL(); - } - - private static boolean _hasGlobalVariable(String varName){ - PyObject mainModule = PyImport_AddModule("__main__"); - PyObject var = PyObject_GetAttrString(mainModule, varName); - boolean hasVar = var != null; - Py_DecRef(var); - return hasVar; - } - - /** - * Executes python code and looks for methods setup() and run() - * If both setup() and run() are found, both are executed for the first - * time and for subsequent calls only run() is executed. - */ - public static void execWithSetupAndRun(String code) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - } - log.debug("Calling run()..."); - _exec("run()"); - } - log.info("Exec done"); - releaseGIL(); - } - - /** - * Executes python code and looks for methods setup() and run() - * If both setup() and run() are found, both are executed for the first - * time and for subsequent calls only run() is executed. - */ - public static void execWithSetupAndRun(String code, PythonVariables pyOutputs) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - } - log.debug("Calling run()..."); - _exec("__out = run();for (k,v) in __out.items(): globals()[k]=v"); - } - log.info("Exec done"); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - /** - * Run the given code with the given python outputs - * @param code the code to run - * @param pyOutputs the outputs to run - */ - public static void exec(String code, PythonVariables pyOutputs) { - - exec(code + '\n' + outputCode(pyOutputs)); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - - /** - * Execute the given python code with the given - * {@link PythonVariables} as inputs and outputs - * @param code the code to run - * @param pyInputs the inputs to the code - * @param pyOutputs the outputs to the code - * @throws Exception - */ - public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { - String inputCode = inputCode(pyInputs); - exec(inputCode + code, pyOutputs); - } - - /** - * Execute the given python code - * with the {@link PythonVariables} - * inputs and outputs for storing the values - * specified by the user and needed by the user - * as output - * @param code the python code to execute - * @param pyInputs the python variables input in to the python script - * @param pyOutputs the python variables output returned by the python script - * @throws Exception - */ - public static void execWithSetupAndRun(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { - String inputCode = inputCode(pyInputs); - code = inputCode +code; - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - releaseGIL(); // required - acquireGIL(); - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - }else{ - log.debug("setup() already called once."); - } - log.debug("Calling run()..."); - releaseGIL(); // required - acquireGIL(); - _exec("import inspect\n"+ - "__out = run(**{k:globals()[k]for k in inspect.getfullargspec(run).args})\n"+ - "globals().update(__out)"); - } - releaseGIL(); // required - acquireGIL(); - _exec(outputCode(pyOutputs)); - log.info("Exec done"); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - - - private static String interpreterNameFromTransform(PythonTransform transform){ - return transform.getName().replace("-", "_"); - } - - - /** - * Run a {@link PythonTransform} with the given inputs - * @param transform the transform to run - * @param inputs the inputs to the transform - * @return the output variables - * @throws Exception - */ - public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception { - String name = interpreterNameFromTransform(transform); - setInterpreter(name); - Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); - exec(transform.getCode(), inputs, transform.getOutputs()); - return transform.getOutputs(); - } - public static PythonVariables execWithSetupAndRun(PythonTransform transform, PythonVariables inputs)throws Exception { - String name = interpreterNameFromTransform(transform); - setInterpreter(name); - Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); - execWithSetupAndRun(transform.getCode(), inputs, transform.getOutputs()); - return transform.getOutputs(); - } - - - /** - * Run the code and return the outputs - * @param code the code to run - * @return all python variables - */ - public static PythonVariables execAndReturnAllVariables(String code) { - exec(code + '\n' + outputCodeForAllVariables()); - PythonVariables allVars = new PythonVariables(); - allVars.addDict(ALL_VARIABLES_KEY); - try { - _readOutputs(allVars); - }catch (IOException e) { - log.error("Failed to read outputs", e); - } - - return expandInnerDict(allVars, ALL_VARIABLES_KEY); - } - public static PythonVariables execWithSetupRunAndReturnAllVariables(String code) { - execWithSetupAndRun(code + '\n' + outputCodeForAllVariables()); - PythonVariables allVars = new PythonVariables(); - allVars.addDict(ALL_VARIABLES_KEY); - try { - _readOutputs(allVars); - }catch (IOException e) { - log.error("Failed to read outputs", e); - } - - return expandInnerDict(allVars, ALL_VARIABLES_KEY); - } - - /** - * - * @param code code string to run - * @param pyInputs python input variables - * @return all python variables - * @throws Exception throws when there's an issue while execution of python code - */ - public static PythonVariables execAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { - String inputCode = inputCode(pyInputs); - return execAndReturnAllVariables(inputCode + code); - } - public static PythonVariables execWithSetupRunAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { - String inputCode = inputCode(pyInputs); - return execWithSetupRunAndReturnAllVariables(inputCode + code); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static String evalString(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addStr(varName); - exec("print('')", vars); - return vars.getStrValue(varName); - } - - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static long evalInteger(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addInt(varName); - exec("print('')", vars); - return vars.getIntValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Double evalFloat(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addFloat(varName); - exec("print('')", vars); - return vars.getFloatValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Object[] evalList(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addList(varName); - exec("pass", vars); - return vars.getListValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Map evalDict(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addDict(varName); - exec("pass", vars); - return vars.getDictValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static NumpyArray evalNdArray(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addNDArray(varName); - exec("pass", vars); - return vars.getNDArrayValue(varName); - } - - private static String outputVarName() { - return "_" + Thread.currentThread().getId() + "_" + currentInterpreter + "_out"; - } - - private static String outputCode(PythonVariables pyOutputs) { - if (pyOutputs == null){ - return ""; - } - - String outputCode = "import json\n"; - String outputFunctions; - try(BufferedInputStream bufferedInputStream = new BufferedInputStream(new ClassPathResource("pythonexec/serialize_array.py").getInputStream())) { - outputFunctions= IOUtils.toString(bufferedInputStream,Charset.defaultCharset()); - outputCode += outputFunctions; - outputCode += "\n"; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python file pythonexec/serialize_arrays.py from classpath"); - } - - outputCode += outputVarName() + " = __serialize_dict({"; - String[] varNames = pyOutputs.getVariables(); - for (String varName: varNames) { - outputCode += "\"" + varName + "\": " + varName + ","; - } - - - if (varNames.length > 0) - outputCode = outputCode.substring(0, outputCode.length() - 1); - outputCode += "})"; - outputCode += "\nwith open('" + getTempFile() + "', 'w') as " + fileVarName + ":" + fileVarName + ".write(" + outputVarName() + ")"; - - - return outputCode; - - } - - private static String jArrayToPyString(Object[] array) { - String str = "["; - for (int i = 0; i < array.length; i++){ - Object obj = array[i]; - if (obj instanceof Object[]){ - str += jArrayToPyString((Object[])obj); - } - else if (obj instanceof String){ - str += "\"" + obj + "\""; - } - else{ - str += obj.toString().replace("\"", "\\\""); - } - if (i < array.length - 1){ - str += ","; - } - - } - str += "]"; - return str; - } - - private static String escapeStr(String str) { - if(str == null) - return null; - str = str.replace("\\", "\\\\"); - str = str.replace("\"\"\"", "\\\"\\\"\\\""); - return str; - } - - private static String getWrappedCode(String code) { - try(InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { - String base = IOUtils.toString(is, Charset.defaultCharset()); - StringBuffer indentedCode = new StringBuffer(); - for(String split : code.split("\n")) { - indentedCode.append(" " + split + "\n"); - - } - - String out = base.replace(" pass",indentedCode); - return out; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python code!",e); - } - - } - - - - private static String getTempFile() { - String ret = "temp_" + Thread.currentThread().getId() + "_" + currentInterpreter + ".json"; - log.info(ret); - return ret; - } - - - private static String outputCodeForAllVariables() { - String outputCode = ""; - try(BufferedInputStream bufferedInputStream = new BufferedInputStream(new ClassPathResource("pythonexec/outputcode.py").getInputStream())) { - outputCode += IOUtils.toString(bufferedInputStream,Charset.defaultCharset()).replace("f2",fileVarName); - outputCode += "\n"; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python file pythonexec/outputcode.py from classpath"); - } - - outputCode += String.format("vars = {key:value for (key,value) in locals().items() if not key.startswith('_') and key is not '%s' and key is not 'loc' and type(value) in (list, dict, str, int, float, bool, type(None))}\n",fileVarName); - outputCode += String.format("with open('" + getTempFile() + "', 'w') as %s:json.dump({",fileVarName); - outputCode +=String.format( "\"" + ALL_VARIABLES_KEY + "\"" + ": vars}, %s)\n",fileVarName); - return outputCode; - } - - - /*-----monkey patch for numpy-----*/ - private static List _getPatches() { - exec("import numpy as np"); - exec( "__overrides_path = np.core.overrides.__file__"); - exec("__random_path = np.random.__file__"); - - List patches = new ArrayList<>(); - - patches.add(new String[]{ - "pythonexec/patch0.py", - evalString("__overrides_path") - }); - patches.add(new String[]{ - "pythonexec/patch1.py", - evalString("__random_path") - }); - - return patches; - } - - private static void _applyPatch(String src, String dest){ - try(InputStream is = new ClassPathResource(src).getInputStream()) { - String patch = IOUtils.toString(is, Charset.defaultCharset()); - FileUtils.write(new File(dest), patch, "utf-8"); - } - catch(IOException e){ - log.warn("Error patching numpy: " + e); - } - } - - private static boolean _checkPatchApplied(String dest) { - try { - return FileUtils.readFileToString(new File(dest), "utf-8").startsWith("#patch"); - } catch (IOException e) { - return false; - } - } - - private static void applyPatches() { - // We patch numpy for partial support of multiple interpreters - for (String[] patch : _getPatches()){ - if (_checkPatchApplied(patch[1])){ - log.info("Patch already applied for " + patch[1]); - } - else{ - _applyPatch(patch[0], patch[1]); - log.info("Applied patch for " + patch[1]); - } - } - for (String[] patch: _getPatches()){ - if (!_checkPatchApplied(patch[1])){ - log.warn("Error patching numpy"); - } - } - } -} \ No newline at end of file +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java new file mode 100644 index 000000000..d8afa2836 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.cpython.PyThreadState; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyEval_RestoreThread; +import static org.bytedeco.cpython.global.python.PyEval_SaveThread; + + +@Slf4j +public class PythonGIL implements AutoCloseable { + private static PyThreadState mainThreadState; + + static { + log.debug("CPython: PyThreadState_Get()"); + mainThreadState = PyThreadState_Get(); + } + + private PythonGIL() { + acquire(); + } + + @Override + public void close() { + release(); + } + + public static PythonGIL lock() { + return new PythonGIL(); + } + + private static synchronized void acquire() { + log.debug("acquireGIL()"); + log.debug("CPython: PyEval_SaveThread()"); + mainThreadState = PyEval_SaveThread(); + log.debug("CPython: PyThreadState_New()"); + PyThreadState ts = PyThreadState_New(mainThreadState.interp()); + log.debug("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(ts); + log.debug("CPython: PyThreadState_Swap()"); + PyThreadState_Swap(ts); + } + + private static synchronized void release() { + log.debug("CPython: PyEval_SaveThread()"); + PyEval_SaveThread(); + log.debug("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(mainThreadState); + } +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java new file mode 100644 index 000000000..c50c9bb9e --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java @@ -0,0 +1,171 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import javax.annotation.Nonnull; +import java.util.HashMap; +import java.util.Map; + + +@Data +@NoArgsConstructor +/** + * PythonJob is the right abstraction for executing multiple python scripts + * in a multi thread stateful environment. The setup-and-run mode allows your + * "setup" code (imports, model loading etc) to be executed only once. + */ +public class PythonJob { + + private String code; + private String name; + private String context; + private boolean setupRunMode; + private PythonObject runF; + + static { + new PythonExecutioner(); + } + + @Builder + /** + * @param name Name for the python job. + * @param code Python code. + * @param setupRunMode If true, the python code is expected to have two methods: setup(), which takes no arguments, + * and run() which takes some or no arguments. setup() method is executed once, + * and the run() method is called with the inputs(if any) per transaction, and is expected to return a dictionary + * mapping from output variable names (str) to output values. + * If false, the full script is run on each transaction and the output variables are obtained from the global namespace + * after execution. + */ + public PythonJob(@Nonnull String name, @Nonnull String code, boolean setupRunMode) throws Exception { + this.name = name; + this.code = code; + this.setupRunMode = setupRunMode; + context = "__job_" + name; + if (PythonContextManager.hasContext(context)) { + throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!"); + } + if (setupRunMode) setup(); + } + + + /** + * Clears all variables in current context and calls setup() + */ + public void clearState() throws Exception { + String context = this.context; + PythonContextManager.setContext("main"); + PythonContextManager.deleteContext(context); + this.context = context; + setup(); + } + + public void setup() throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + PythonObject runF = PythonExecutioner.getVariable("run"); + if (runF.isNone() || !Python.callable(runF)) { + PythonExecutioner.exec(code); + runF = PythonExecutioner.getVariable("run"); + } + if (runF.isNone() || !Python.callable(runF)) { + throw new PythonException("run() method not found! " + + "If a PythonJob is created with 'setup and run' " + + "mode enabled, the associated python code is " + + "expected to contain a run() method " + + "(with or without arguments)."); + } + this.runF = runF; + PythonObject setupF = PythonExecutioner.getVariable("setup"); + if (!setupF.isNone()) { + setupF.call(); + } + } + } + + public void exec(PythonVariables inputs, PythonVariables outputs) throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + if (!setupRunMode) { + PythonExecutioner.exec(code, inputs, outputs); + return; + } + PythonExecutioner.setVariables(inputs); + + PythonObject inspect = Python.importModule("inspect"); + PythonObject getfullargspec = inspect.attr("getfullargspec"); + PythonObject argspec = getfullargspec.call(runF); + PythonObject argsList = argspec.attr("args"); + PythonObject runargs = Python.dict(); + int argsCount = Python.len(argsList).toInt(); + for (int i = 0; i < argsCount; i++) { + PythonObject arg = argsList.get(i); + PythonObject val = Python.globals().get(arg); + if (val.isNone()) { + throw new PythonException("Input value not received for run() argument: " + arg.toString()); + } + runargs.set(arg, val); + } + PythonObject outDict = runF.callWithKwargs(runargs); + Python.globals().attr("update").call(outDict); + + PythonExecutioner.getVariables(outputs); + inspect.del(); + getfullargspec.del(); + argspec.del(); + runargs.del(); + } + } + + public PythonVariables execAndReturnAllVariables(PythonVariables inputs) throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + if (!setupRunMode) { + return PythonExecutioner.execAndReturnAllVariables(code, inputs); + } + PythonExecutioner.setVariables(inputs); + PythonObject inspect = Python.importModule("inspect"); + PythonObject getfullargspec = inspect.attr("getfullargspec"); + PythonObject argspec = getfullargspec.call(runF); + PythonObject argsList = argspec.attr("args"); + PythonObject runargs = Python.dict(); + int argsCount = Python.len(argsList).toInt(); + for (int i = 0; i < argsCount; i++) { + PythonObject arg = argsList.get(i); + PythonObject val = Python.globals().get(arg); + if (val.isNone()) { + throw new PythonException("Input value not received for run() argument: " + arg.toString()); + } + runargs.set(arg, val); + } + PythonObject outDict = runF.callWithKwargs(runargs); + Python.globals().attr("update").call(outDict); + inspect.del(); + getfullargspec.del(); + argspec.del(); + runargs.del(); + return PythonExecutioner.getAllVariables(); + } + } + + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java new file mode 100644 index 000000000..f1d54168b --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -0,0 +1,554 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + + +import org.bytedeco.cpython.PyObject; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.json.JSONArray; +import org.json.JSONObject; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.nativeblas.NativeOpsHolder; + +import java.util.*; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyObject_SetItem; + +/** + * Swift like python wrapper for J + * + * @author Fariz Rahman + */ + +public class PythonObject { + private PyObject nativePythonObject; + + static { + new PythonExecutioner(); + } + + private static Map _getNDArraySerializer() { + Map ndarraySerializer = new HashMap<>(); + PythonObject lambda = Python.eval( + "lambda x: " + + "{'address':" + + "x.__array_interface__['data'][0]," + + "'shape':x.shape,'strides':x.strides," + + "'dtype': str(x.dtype),'_is_numpy_array': True}" + + " if str(type(x))== \"\" else x"); + ndarraySerializer.put("default", + lambda); + return ndarraySerializer; + + } + + public PythonObject(PyObject pyObject) { + nativePythonObject = pyObject; + } + + public PythonObject(INDArray npArray) { + this(new NumpyArray(npArray)); + } + + public PythonObject(BytePointer bp){ + nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity()); + } + + public PythonObject(NumpyArray npArray) { + PyObject ctypes = PyImport_ImportModule("ctypes"); + PyObject np = PyImport_ImportModule("numpy"); + PyObject ctype; + switch (npArray.getDtype()) { + case DOUBLE: + ctype = PyObject_GetAttrString(ctypes, "c_double"); + break; + case FLOAT: + ctype = PyObject_GetAttrString(ctypes, "c_float"); + break; + case LONG: + ctype = PyObject_GetAttrString(ctypes, "c_int64"); + break; + case INT: + ctype = PyObject_GetAttrString(ctypes, "c_int32"); + break; + case SHORT: + ctype = PyObject_GetAttrString(ctypes, "c_int16"); + break; + case UINT16: + ctype = PyObject_GetAttrString(ctypes, "c_uint16"); + break; + case UINT32: + ctype = PyObject_GetAttrString(ctypes, "c_uint32"); + break; + case UINT64: + ctype = PyObject_GetAttrString(ctypes, "c_uint64"); + break; + case BOOL: + ctype = PyObject_GetAttrString(ctypes, "c_bool"); + break; + case BYTE: + ctype = PyObject_GetAttrString(ctypes, "c_byte"); + break; + case UBYTE: + ctype = PyObject_GetAttrString(ctypes, "c_ubyte"); + break; + default: + throw new RuntimeException("Unsupported dtype: " + npArray.getDtype()); + } + + PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER"); + PyObject argsTuple = PyTuple_New(1); + PyTuple_SetItem(argsTuple, 0, ctype); + PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null); + + PyObject cast = PyObject_GetAttrString(ctypes, "cast"); + PyObject address = PyLong_FromLong(npArray.getAddress()); + PyObject argsTuple2 = PyTuple_New(2); + PyTuple_SetItem(argsTuple2, 0, address); + PyTuple_SetItem(argsTuple2, 1, ptrType); + PyObject ptr = PyObject_Call(cast, argsTuple2, null); + PyObject shapeTuple = PyTuple_New(npArray.getShape().length); + for (int i = 0; i < npArray.getShape().length; i++) { + PyObject dim = PyLong_FromLong(npArray.getShape()[i]); + PyTuple_SetItem(shapeTuple, i, dim); + Py_DecRef(dim); + } + PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib"); + PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array"); + PyObject argsTuple3 = PyTuple_New(2); + PyTuple_SetItem(argsTuple3, 0, ptr); + PyTuple_SetItem(argsTuple3, 1, shapeTuple); + nativePythonObject = PyObject_Call(asArray, argsTuple3, null); + + Py_DecRef(ctypesPointer); + Py_DecRef(ctypesLib); + Py_DecRef(argsTuple); + Py_DecRef(argsTuple2); + Py_DecRef(argsTuple3); + Py_DecRef(cast); + Py_DecRef(asArray); + + } + + /*---primitve constructors---*/ + public PyObject getNativePythonObject() { + return nativePythonObject; + } + + public PythonObject(String data) { + nativePythonObject = PyUnicode_FromString(data); + } + + public PythonObject(int data) { + nativePythonObject = PyLong_FromLong((long) data); + } + + public PythonObject(long data) { + nativePythonObject = PyLong_FromLong(data); + } + + public PythonObject(double data) { + nativePythonObject = PyFloat_FromDouble(data); + } + + public PythonObject(boolean data) { + nativePythonObject = PyBool_FromLong(data ? 1 : 0); + } + + private static PythonObject j2pyObject(Object item) { + if (item instanceof PythonObject) { + return (PythonObject) item; + } else if (item instanceof PyObject) { + return new PythonObject((PyObject) item); + } else if (item instanceof INDArray) { + return new PythonObject((INDArray) item); + } else if (item instanceof NumpyArray) { + return new PythonObject((NumpyArray) item); + } else if (item instanceof List) { + return new PythonObject((List) item); + } else if (item instanceof Object[]) { + return new PythonObject((Object[]) item); + } else if (item instanceof Map) { + return new PythonObject((Map) item); + } else if (item instanceof String) { + return new PythonObject((String) item); + } else if (item instanceof Double) { + return new PythonObject((Double) item); + } else if (item instanceof Float) { + return new PythonObject((Float) item); + } else if (item instanceof Long) { + return new PythonObject((Long) item); + } else if (item instanceof Integer) { + return new PythonObject((Integer) item); + } else if (item instanceof Boolean) { + return new PythonObject((Boolean) item); + } else if (item instanceof Pointer){ + return new PythonObject(new BytePointer((Pointer)item)); + } else { + throw new RuntimeException("Unsupported item in list: " + item); + } + } + + public PythonObject(Object[] data) { + PyObject pyList = PyList_New((long) data.length); + for (int i = 0; i < data.length; i++) { + PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject); + } + nativePythonObject = pyList; + } + + public PythonObject(List data) { + PyObject pyList = PyList_New((long) data.size()); + for (int i = 0; i < data.size(); i++) { + PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject); + } + nativePythonObject = pyList; + } + + public PythonObject(Map data) { + PyObject pyDict = PyDict_New(); + for (Object k : data.keySet()) { + PythonObject pyKey; + if (k instanceof PythonObject) { + pyKey = (PythonObject) k; + } else if (k instanceof String) { + pyKey = new PythonObject((String) k); + } else if (k instanceof Double) { + pyKey = new PythonObject((Double) k); + } else if (k instanceof Float) { + pyKey = new PythonObject((Float) k); + } else if (k instanceof Long) { + pyKey = new PythonObject((Long) k); + } else if (k instanceof Integer) { + pyKey = new PythonObject((Integer) k); + } else if (k instanceof Boolean) { + pyKey = new PythonObject((Boolean) k); + } else { + throw new RuntimeException("Unsupported key in map: " + k.getClass()); + } + Object v = data.get(k); + PythonObject pyVal; + if (v instanceof PythonObject) { + pyVal = (PythonObject) v; + } else if (v instanceof PyObject) { + pyVal = new PythonObject((PyObject) v); + } else if (v instanceof INDArray) { + pyVal = new PythonObject((INDArray) v); + } else if (v instanceof NumpyArray) { + pyVal = new PythonObject((NumpyArray) v); + } else if (v instanceof Map) { + pyVal = new PythonObject((Map) v); + } else if (v instanceof List) { + pyVal = new PythonObject((List) v); + } else if (v instanceof String) { + pyVal = new PythonObject((String) v); + } else if (v instanceof Double) { + pyVal = new PythonObject((Double) v); + } else if (v instanceof Float) { + pyVal = new PythonObject((Float) v); + } else if (v instanceof Long) { + pyVal = new PythonObject((Long) v); + } else if (v instanceof Integer) { + pyVal = new PythonObject((Integer) v); + } else if (v instanceof Boolean) { + pyVal = new PythonObject((Boolean) v); + } else { + throw new RuntimeException("Unsupported value in map: " + k.getClass()); + } + + PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject); + + } + nativePythonObject = pyDict; + } + + + /*------*/ + + private static String pyObjectToString(PyObject pyObject) { + PyObject repr = PyObject_Str(pyObject); + PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~"); + String jstr = PyBytes_AsString(str).getString(); + Py_DecRef(repr); + Py_DecRef(str); + return jstr; + } + + public String toString() { + return pyObjectToString(nativePythonObject); + } + + public double toDouble() { + return PyFloat_AsDouble(nativePythonObject); + } + + public float toFloat() { + return (float) PyFloat_AsDouble(nativePythonObject); + } + + public int toInt() { + return (int) PyLong_AsLong(nativePythonObject); + } + + public long toLong() { + return PyLong_AsLong(nativePythonObject); + } + + public boolean toBoolean() { + if (isNone()) return false; + return toInt() != 0; + } + + public NumpyArray toNumpy() { + PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() ! + PyObject data = PyDict_GetItemString(arrInterface, "data"); + PyObject pyAddress = PyTuple_GetItem(data, 0); + long address = PyLong_AsLong(pyAddress); + PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype"); + PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name"); + String dtypeName = pyObjectToString(pyDtypeName); + Py_DecRef(pyDtype); + Py_DecRef(pyDtypeName); + PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape"); + PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides"); + int ndim = (int) PyObject_Size(shape); + long[] jshape = new long[ndim]; + long[] jstrides = new long[ndim]; + for (int i = 0; i < ndim; i++) { + jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i)); + jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i)); + } + Py_DecRef(shape); + Py_DecRef(strides); + DataType dtype; + if (dtypeName.equals("float64")) { + dtype = DataType.DOUBLE; + } else if (dtypeName.equals("float32")) { + dtype = DataType.FLOAT; + } else if (dtypeName.equals("int16")) { + dtype = DataType.SHORT; + } else if (dtypeName.equals("int32")) { + dtype = DataType.INT; + } else if (dtypeName.equals("int64")) { + dtype = DataType.LONG; + } else { + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + return new NumpyArray(address, jshape, jstrides, dtype); + + } + + public PythonObject attr(String attr) { + + return new PythonObject(PyObject_GetAttrString(nativePythonObject, attr)); + } + + public PythonObject call(Object... args) { + if (args.length > 0 && args[args.length - 1] instanceof Map) { + List args2 = new ArrayList<>(); + for (int i = 0; i < args.length - 1; i++) { + args2.add(args[i]); + } + return call(args2, (Map) args[args.length - 1]); + } + if (args.length == 0) { + return new PythonObject(PyObject_CallObject(nativePythonObject, null)); + } + PyObject tuple = PyTuple_New(args.length); // leaky; tuple may contain borrowed references, so can not be de-allocated. + for (int i = 0; i < args.length; i++) { + PyTuple_SetItem(tuple, i, j2pyObject(args[i]).nativePythonObject); + } + PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + return ret; + } + + public PythonObject callWithArgs(PythonObject args) { + PyObject tuple = PyList_AsTuple(args.nativePythonObject); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + } + + public PythonObject callWithKwargs(PythonObject kwargs) { + PyObject tuple = PyTuple_New(0); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs.nativePythonObject)); + } + + public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) { + PyObject tuple = PyList_AsTuple(args.nativePythonObject); + PyObject dict = kwargs.nativePythonObject; + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + public PythonObject call(Map kwargs) { + PyObject dict = new PythonObject(kwargs).nativePythonObject; + PyObject tuple = PyTuple_New(0); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + public PythonObject call(List args) { + PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + } + + public PythonObject call(List args, Map kwargs) { + PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject); + PyObject dict = new PythonObject(kwargs).nativePythonObject; + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + private PythonObject get(PyObject key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, key) + ); + } + + public PythonObject get(PythonObject key) { + return get(key.nativePythonObject); + } + + + public PythonObject get(int key) { + return get(PyLong_FromLong((long) key)); + } + + public PythonObject get(long key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, PyLong_FromLong(key)) + + ); + } + + public PythonObject get(double key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key)) + + ); + } + + public PythonObject get(String key) { + return get(new PythonObject(key)); + } + + public void set(PythonObject key, PythonObject value) { + PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); + } + + public void del() { + Py_DecRef(nativePythonObject); + nativePythonObject = null; + } + + public JSONArray toJSONArray() throws PythonException { + PythonObject json = Python.importModule("json"); + PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer()); + String jsonString = serialized.toString(); + return new JSONArray(jsonString); + + } + + public JSONObject toJSONObject() throws PythonException { + PythonObject json = Python.importModule("json"); + PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer()); + String jsonString = serialized.toString(); + return new JSONObject(jsonString); + } + + public List toList() throws PythonException{ + List list = new ArrayList(); + int n = Python.len(this).toInt(); + for (int i = 0; i < n; i++) { + PythonObject o = get(i); + if (Python.isinstance(o, Python.strType())) { + list.add(o.toString()); + } else if (Python.isinstance(o, Python.intType())) { + list.add(o.toLong()); + } else if (Python.isinstance(o, Python.floatType())) { + list.add(o.toDouble()); + } else if (Python.isinstance(o, Python.boolType())) { + list.add(o); + } else if (Python.isinstance(o, Python.listType(), Python.tupleType())) { + list.add(o.toList()); + } else if (Python.isinstance(o, Python.importModule("numpy").attr("ndarray"))) { + list.add(o.toNumpy().getNd4jArray()); + } else if (Python.isinstance(o, Python.dictType())) { + list.add(o.toMap()); + } else { + throw new RuntimeException("Error while converting python" + + " list to java List: Unable to serialize python " + + "object of type " + Python.type(this).toString()); + } + } + + return list; + } + + public Map toMap() throws PythonException{ + Map map = new HashMap(); + List keys = Python.list(attr("keys").call()).toList(); + List values = Python.list(attr("values").call()).toList(); + for (int i = 0; i < keys.size(); i++) { + map.put(keys.get(i), values.get(i)); + } + return map; + } + + public BytePointer toBytePointer() throws PythonException{ + if (Python.isinstance(this, Python.bytesType())){ + PyObject byteArray = PyByteArray_FromObject(nativePythonObject); + return PyByteArray_AsString(byteArray); + + } + else if (Python.isinstance(this, Python.bytearrayType())){ + return PyByteArray_AsString(nativePythonObject); + } + else{ + PyObject ctypes = PyImport_ImportModule("ctypes"); + PyObject cArrType = PyObject_GetAttrString(ctypes, "Array"); + if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){ + PyObject cVoidP = PyObject_GetAttrString(ctypes, "c_void_p"); + PyObject cast = PyObject_GetAttrString(ctypes, "cast"); + PyObject argsTuple = PyTuple_New(2); + PyTuple_SetItem(argsTuple, 0, nativePythonObject); + PyTuple_SetItem(argsTuple, 1, cVoidP); + PyObject voidPtr = PyObject_Call(cast, argsTuple, null); + PyObject pyAddress = PyObject_GetAttrString(voidPtr, "value"); + long address = PyLong_AsLong(pyAddress); + long size = PyObject_Size(nativePythonObject); + Py_DecRef(ctypes); + Py_DecRef(cArrType); + Py_DecRef(argsTuple); + Py_DecRef(voidPtr); + Py_DecRef(pyAddress); + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address); + ptr = ptr.limit(size); + ptr = ptr.capacity(size); + return new BytePointer(ptr); + } + else{ + throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString()); + } + + } + } + public boolean isNone() { + return nativePythonObject == null; + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index 8f2460035..ab67adc46 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -24,10 +24,13 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; +import org.json.JSONPropertyIgnore; import org.nd4j.base.Preconditions; import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import java.io.IOException; import java.io.InputStream; @@ -52,12 +55,13 @@ public class PythonTransform implements Transform { private String code; private PythonVariables inputs; private PythonVariables outputs; - private String name = UUID.randomUUID().toString(); + private String name = UUID.randomUUID().toString(); private Schema inputSchema; private Schema outputSchema; private String outputDict; private boolean returnAllVariables; private boolean setupAndRun = false; + private PythonJob pythonJob; @Builder @@ -70,71 +74,70 @@ public class PythonTransform implements Transform { String outputDict, boolean returnAllInputs, boolean setupAndRun) { - Preconditions.checkNotNull(code,"No code found to run!"); + Preconditions.checkNotNull(code, "No code found to run!"); this.code = code; this.returnAllVariables = returnAllInputs; this.setupAndRun = setupAndRun; - if(inputs != null) + if (inputs != null) this.inputs = inputs; - if(outputs != null) + if (outputs != null) this.outputs = outputs; - - if(name != null) + if (name != null) this.name = name; if (outputDict != null) { this.outputDict = outputDict; this.outputs = new PythonVariables(); this.outputs.addDict(outputDict); - - String helpers; - try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) { - helpers = IOUtils.toString(is, Charset.defaultCharset()); - - }catch (IOException e){ - throw new RuntimeException("Error reading python code"); - } - this.code += "\n\n" + helpers; - this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")"; } try { - if(inputSchema != null) { + if (inputSchema != null) { this.inputSchema = inputSchema; - if(inputs == null || inputs.isEmpty()) { + if (inputs == null || inputs.isEmpty()) { this.inputs = schemaToPythonVariables(inputSchema); } } - if(outputSchema != null) { + if (outputSchema != null) { this.outputSchema = outputSchema; - if(outputs == null || outputs.isEmpty()) { + if (outputs == null || outputs.isEmpty()) { this.outputs = schemaToPythonVariables(outputSchema); } } - }catch(Exception e) { + } catch (Exception e) { throw new IllegalStateException(e); } + try{ + pythonJob = PythonJob.builder() + .name("a" + UUID.randomUUID().toString().replace("-", "_")) + .code(code) + .setupRunMode(setupAndRun) + .build(); + } + catch(Exception e){ + throw new IllegalStateException("Error creating python job: " + e); + } } @Override public void setInputSchema(Schema inputSchema) { - Preconditions.checkNotNull(inputSchema,"No input schema found!"); + Preconditions.checkNotNull(inputSchema, "No input schema found!"); this.inputSchema = inputSchema; - try{ + try { inputs = schemaToPythonVariables(inputSchema); - }catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } - if (outputSchema == null && outputDict == null){ + if (outputSchema == null && outputDict == null) { outputSchema = inputSchema; } } @Override - public Schema getInputSchema(){ + public Schema getInputSchema() { return inputSchema; } @@ -158,67 +161,51 @@ public class PythonTransform implements Transform { } - - @Override public List map(List writables) { PythonVariables pyInputs = getPyInputsFromWritables(writables); - Preconditions.checkNotNull(pyInputs,"Inputs must not be null!"); - - - try{ + Preconditions.checkNotNull(pyInputs, "Inputs must not be null!"); + try { if (returnAllVariables) { - if (setupAndRun){ - return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs)); - } - return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs)); + return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs)); } if (outputDict != null) { - if (setupAndRun) { - PythonExecutioner.execWithSetupAndRun(this, pyInputs); - }else{ - PythonExecutioner.exec(this, pyInputs); - } + pythonJob.exec(pyInputs, outputs); PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict); return getWritablesFromPyOutputs(out); - } - else { - if (setupAndRun) { - PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs); - }else{ - PythonExecutioner.exec(code, pyInputs, outputs); - } + } else { + pythonJob.exec(pyInputs, outputs); return getWritablesFromPyOutputs(outputs); } - } - catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } } @Override - public String[] outputColumnNames(){ + public String[] outputColumnNames() { return outputs.getVariables(); } @Override - public String outputColumnName(){ + public String outputColumnName() { return outputColumnNames()[0]; } + @Override - public String[] columnNames(){ + public String[] columnNames() { return outputs.getVariables(); } @Override - public String columnName(){ + public String columnName() { return columnNames()[0]; } - public Schema transform(Schema inputSchema){ + public Schema transform(Schema inputSchema) { return outputSchema; } @@ -226,33 +213,33 @@ public class PythonTransform implements Transform { private PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (String name: inputs.getVariables()) { + for (String name : inputs.getVariables()) { int colIdx = inputSchema.getIndexOfColumn(name); Writable w = writables.get(colIdx); - PythonVariables.Type pyType = inputs.getType(name); - switch (pyType){ + PythonType pyType = inputs.getType(name); + switch (pyType.getName()) { case INT: - if (w instanceof LongWritable){ - ret.addInt(name, ((LongWritable)w).get()); + if (w instanceof LongWritable) { + ret.addInt(name, ((LongWritable) w).get()); + } else { + ret.addInt(name, ((IntWritable) w).get()); } - else{ - ret.addInt(name, ((IntWritable)w).get()); - } - break; case FLOAT: if (w instanceof DoubleWritable) { - ret.addFloat(name, ((DoubleWritable)w).get()); - } - else{ - ret.addFloat(name, ((FloatWritable)w).get()); + ret.addFloat(name, ((DoubleWritable) w).get()); + } else { + ret.addFloat(name, ((FloatWritable) w).get()); } break; case STR: ret.addStr(name, w.toString()); break; case NDARRAY: - ret.addNDArray(name,((NDArrayWritable)w).get()); + ret.addNDArray(name, ((NDArrayWritable) w).get()); + break; + case BOOL: + ret.addBool(name, ((BooleanWritable) w).get()); break; default: throw new RuntimeException("Unsupported input type:" + pyType); @@ -269,8 +256,8 @@ public class PythonTransform implements Transform { Schema.Builder schemaBuilder = new Schema.Builder(); for (int i = 0; i < varNames.length; i++) { String name = varNames[i]; - PythonVariables.Type pyType = pyOuts.getType(name); - switch (pyType){ + PythonType pyType = pyOuts.getType(name); + switch (pyType.getName()) { case INT: schemaBuilder.addColumnLong(name); break; @@ -283,11 +270,14 @@ public class PythonTransform implements Transform { schemaBuilder.addColumnString(name); break; case NDARRAY: - NumpyArray arr = pyOuts.getNDArrayValue(name); - schemaBuilder.addColumnNDArray(name, arr.getShape()); + INDArray arr = pyOuts.getNDArrayValue(name); + schemaBuilder.addColumnNDArray(name, arr.shape()); + break; + case BOOL: + schemaBuilder.addColumnBoolean(name); break; default: - throw new IllegalStateException("Unable to support type " + pyType.name()); + throw new IllegalStateException("Unable to support type " + pyType.getName()); } } this.outputSchema = schemaBuilder.build(); @@ -295,9 +285,9 @@ public class PythonTransform implements Transform { for (int i = 0; i < varNames.length; i++) { String name = varNames[i]; - PythonVariables.Type pyType = pyOuts.getType(name); + PythonType pyType = pyOuts.getType(name); - switch (pyType){ + switch (pyType.getName()) { case INT: out.add(new LongWritable(pyOuts.getIntValue(name))); break; @@ -308,14 +298,14 @@ public class PythonTransform implements Transform { out.add(new Text(pyOuts.getStrValue(name))); break; case NDARRAY: - NumpyArray arr = pyOuts.getNDArrayValue(name); - out.add(new NDArrayWritable(arr.getNd4jArray())); + INDArray arr = pyOuts.getNDArrayValue(name); + out.add(new NDArrayWritable(arr)); break; case DICT: Map dictValue = pyOuts.getDictValue(name); Map noNullValues = new java.util.HashMap<>(); - for(Map.Entry entry : dictValue.entrySet()) { - if(entry.getValue() != org.json.JSONObject.NULL) { + for (Map.Entry entry : dictValue.entrySet()) { + if (entry.getValue() != org.json.JSONObject.NULL) { noNullValues.put(entry.getKey(), entry.getValue()); } } @@ -327,21 +317,22 @@ public class PythonTransform implements Transform { } break; case LIST: - Object[] listValue = pyOuts.getListValue(name); + Object[] listValue = pyOuts.getListValue(name).toArray(); try { out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue))); } catch (JsonProcessingException e) { throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); } break; + case BOOL: + out.add(new BooleanWritable(pyOuts.getBooleanValue(name))); + break; default: - throw new IllegalStateException("Unable to support type " + pyType.name()); + throw new IllegalStateException("Unable to support type " + pyType.getName()); } } return out; } - - } \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java new file mode 100644 index 000000000..60603a8e3 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java @@ -0,0 +1,238 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.datavec.python.Python.importModule; + + +/** + * + * @param Corresponding Java type for the Python type + */ +public abstract class PythonType { + + public abstract T toJava(PythonObject pythonObject) throws PythonException; + private final TypeName typeName; + + enum TypeName{ + STR, + INT, + FLOAT, + BOOL, + LIST, + DICT, + NDARRAY, + BYTES + } + private PythonType(TypeName typeName){ + this.typeName = typeName; + } + public TypeName getName(){return typeName;} + public String toString(){ + return getName().name(); + } + public static PythonType valueOf(String typeName) throws PythonException{ + try{ + typeName.valueOf(typeName); + } catch (IllegalArgumentException iae){ + throw new PythonException("Invalid python type: " + typeName, iae); + } + try{ + return (PythonType)PythonType.class.getField(typeName).get(null); // shouldn't fail + } catch (Exception e){ + throw new RuntimeException(e); + } + + } + public static PythonType valueOf(TypeName typeName){ + try{ + return valueOf(typeName.name()); // shouldn't fail + }catch (PythonException pe){ + throw new RuntimeException(pe); + } + } + + /** + * Since multiple java types can map to the same python type, + * this method "normalizes" all supported incoming objects to T + * + * @param object object to be converted to type T + * @return + */ + public T convert(Object object) throws PythonException { + return (T) object; + } + + public static final PythonType STR = new PythonType(TypeName.STR) { + @Override + public String toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.strType())) { + throw new PythonException("Expected variable to be str, but was " + Python.type(pythonObject)); + } + return pythonObject.toString(); + } + + @Override + public String convert(Object object) { + return object.toString(); + } + }; + + public static final PythonType INT = new PythonType(TypeName.INT) { + @Override + public Long toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.intType())) { + throw new PythonException("Expected variable to be int, but was " + Python.type(pythonObject)); + } + return pythonObject.toLong(); + } + + @Override + public Long convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).longValue(); + } + throw new PythonException("Unable to cast " + object + " to Long."); + } + }; + + public static final PythonType FLOAT = new PythonType(TypeName.FLOAT) { + @Override + public Double toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.floatType())) { + throw new PythonException("Expected variable to be float, but was " + Python.type(pythonObject)); + } + return pythonObject.toDouble(); + } + + @Override + public Double convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).doubleValue(); + } + throw new PythonException("Unable to cast " + object + " to Double."); + } + }; + + public static final PythonType BOOL = new PythonType(TypeName.BOOL) { + @Override + public Boolean toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.boolType())) { + throw new PythonException("Expected variable to be bool, but was " + Python.type(pythonObject)); + } + return pythonObject.toBoolean(); + } + + @Override + public Boolean convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).intValue() != 0; + } else if (object instanceof Boolean) { + return (Boolean) object; + } + throw new PythonException("Unable to cast " + object + " to Boolean."); + } + }; + + public static final PythonType LIST = new PythonType(TypeName.LIST) { + @Override + public List toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.listType())) { + throw new PythonException("Expected variable to be list, but was " + Python.type(pythonObject)); + } + return pythonObject.toList(); + } + + @Override + public List convert(Object object) throws PythonException { + if (object instanceof java.util.List) { + return (List) object; + } else if (object instanceof org.json.JSONArray) { + org.json.JSONArray jsonArray = (org.json.JSONArray) object; + return jsonArray.toList(); + + } else if (object instanceof Object[]) { + return Arrays.asList((Object[]) object); + } + throw new PythonException("Unable to cast " + object + " to List."); + } + }; + + public static final PythonType DICT = new PythonType(TypeName.DICT) { + @Override + public Map toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.dictType())) { + throw new PythonException("Expected variable to be dict, but was " + Python.type(pythonObject)); + } + return pythonObject.toMap(); + } + + @Override + public Map convert(Object object) throws PythonException { + if (object instanceof Map) { + return (Map) object; + } + throw new PythonException("Unable to cast " + object + " to Map."); + } + }; + + public static final PythonType NDARRAY = new PythonType(TypeName.NDARRAY) { + @Override + public INDArray toJava(PythonObject pythonObject) throws PythonException { + PythonObject np = importModule("numpy"); + if (!Python.isinstance(pythonObject, np.attr("ndarray"), np.attr("generic"))) { + throw new PythonException("Expected variable to be numpy.ndarray, but was " + Python.type(pythonObject)); + } + return pythonObject.toNumpy().getNd4jArray(); + } + + @Override + public INDArray convert(Object object) throws PythonException { + if (object instanceof INDArray) { + return (INDArray) object; + } else if (object instanceof NumpyArray) { + return ((NumpyArray) object).getNd4jArray(); + } + throw new PythonException("Unable to cast " + object + " to INDArray."); + } + }; + + public static final PythonType BYTES = new PythonType(TypeName.BYTES) { + @Override + public BytePointer toJava(PythonObject pythonObject) throws PythonException { + return pythonObject.toBytePointer(); + } + + @Override + public BytePointer convert(Object object) throws PythonException { + if (object instanceof BytePointer) { + return (BytePointer) object; + } else if (object instanceof Pointer) { + return new BytePointer((Pointer) object); + } + throw new PythonException("Unable to cast " + object + " to BytePointer."); + } + }; +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java index a8334cbc5..c510d8130 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java @@ -24,28 +24,30 @@ public class PythonUtils { * Create a {@link Schema} * from {@link PythonVariables}. * Types are mapped to types of the same name. + * * @param input the input {@link PythonVariables} * @return the output {@link Schema} */ public static Schema fromPythonVariables(PythonVariables input) { Schema.Builder schemaBuilder = new Schema.Builder(); - Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none."); - for(Map.Entry entry : input.getVars().entrySet()) { - switch(entry.getValue()) { + Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none."); + for (String varName: input.getVariables()) { + + switch (input.getType(varName).getName()) { case INT: - schemaBuilder.addColumnInteger(entry.getKey()); + schemaBuilder.addColumnInteger(varName); break; case STR: - schemaBuilder.addColumnString(entry.getKey()); + schemaBuilder.addColumnString(varName); break; case FLOAT: - schemaBuilder.addColumnFloat(entry.getKey()); + schemaBuilder.addColumnFloat(varName); break; case NDARRAY: - schemaBuilder.addColumnNDArray(entry.getKey(),null); + schemaBuilder.addColumnNDArray(varName, null); break; case BOOL: - schemaBuilder.addColumn(new BooleanMetaData(entry.getKey())); + schemaBuilder.addColumn(new BooleanMetaData(varName)); } } @@ -56,34 +58,36 @@ public class PythonUtils { * Create a {@link Schema} from an input * {@link PythonVariables} * Types are mapped to types of the same name + * * @param input the input schema * @return the output python variables. */ public static PythonVariables fromSchema(Schema input) { PythonVariables ret = new PythonVariables(); - for(int i = 0; i < input.numColumns(); i++) { + for (int i = 0; i < input.numColumns(); i++) { String currColumnName = input.getName(i); ColumnType columnType = input.getType(i); - switch(columnType) { + switch (columnType) { case NDArray: - ret.add(currColumnName, PythonVariables.Type.NDARRAY); + ret.add(currColumnName, PythonType.NDARRAY); break; case Boolean: - ret.add(currColumnName, PythonVariables.Type.BOOL); + ret.add(currColumnName, PythonType.BOOL); break; case Categorical: case String: - ret.add(currColumnName, PythonVariables.Type.STR); + ret.add(currColumnName, PythonType.STR); break; case Double: case Float: - ret.add(currColumnName, PythonVariables.Type.FLOAT); + ret.add(currColumnName, PythonType.FLOAT); break; case Integer: case Long: - ret.add(currColumnName, PythonVariables.Type.INT); + ret.add(currColumnName, PythonType.INT); break; case Bytes: + ret.add(currColumnName, PythonType.BYTES); break; case Time: throw new UnsupportedOperationException("Unable to process dates with python yet."); @@ -92,9 +96,11 @@ public class PythonUtils { return ret; } + /** * Convert a {@link Schema} * to {@link PythonVariables} + * * @param schema the input schema * @return the output {@link PythonVariables} where each * name in the map is associated with a column name in the schema. @@ -107,7 +113,7 @@ public class PythonUtils { for (int i = 0; i < numCols; i++) { String colName = schema.getName(i); ColumnType colType = schema.getType(i); - switch (colType){ + switch (colType) { case Long: case Integer: pyVars.addInt(colName); @@ -122,6 +128,9 @@ public class PythonUtils { case NDArray: pyVars.addNDArray(colName); break; + case Boolean: + pyVars.addBool(colName); + break; default: throw new Exception("Unsupported python input type: " + colType.toString()); } @@ -131,117 +140,104 @@ public class PythonUtils { } - public static NumpyArray mapToNumpyArray(Map map){ - String dtypeName = (String)map.get("dtype"); + public static NumpyArray mapToNumpyArray(Map map) { + String dtypeName = (String) map.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")){ + } else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")){ + } else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")){ + } else if (dtypeName.equals("int32")) { dtype = DataType.INT; - } - else if (dtypeName.equals("int64")){ + } else if (dtypeName.equals("int64")) { dtype = DataType.LONG; - } - else{ + } else { throw new RuntimeException("Unsupported array type " + dtypeName + "."); } - List shapeList = (List)map.get("shape"); + List shapeList = (List) map.get("shape"); long[] shape = new long[shapeList.size()]; for (int i = 0; i < shape.length; i++) { - shape[i] = (Long)shapeList.get(i); + shape[i] = (Long) shapeList.get(i); } - List strideList = (List)map.get("shape"); + List strideList = (List) map.get("shape"); long[] stride = new long[strideList.size()]; for (int i = 0; i < stride.length; i++) { - stride[i] = (Long)strideList.get(i); + stride[i] = (Long) strideList.get(i); } - long address = (Long)map.get("address"); - NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + long address = (Long) map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true); return numpyArray; } - public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){ + public static PythonVariables expandInnerDict(PythonVariables pyvars, String key) { Map dict = pyvars.getDictValue(key); - String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]); + String[] keys = (String[]) dict.keySet().toArray(new String[dict.keySet().size()]); PythonVariables pyvars2 = new PythonVariables(); - for (String subkey: keys){ + for (String subkey : keys) { Object value = dict.get(subkey); - if (value instanceof Map){ - Map map = (Map)value; - if (map.containsKey("_is_numpy_array")){ + if (value instanceof Map) { + Map map = (Map) value; + if (map.containsKey("_is_numpy_array")) { pyvars2.addNDArray(subkey, mapToNumpyArray(map)); - } - else{ - pyvars2.addDict(subkey, (Map)value); + } else { + pyvars2.addDict(subkey, (Map) value); } - } - else if (value instanceof List){ + } else if (value instanceof List) { pyvars2.addList(subkey, ((List) value).toArray()); - } - else if (value instanceof String){ - System.out.println((String)value); + } else if (value instanceof String) { + System.out.println((String) value); pyvars2.addStr(subkey, (String) value); - } - else if (value instanceof Integer || value instanceof Long) { + } else if (value instanceof Integer || value instanceof Long) { Number number = (Number) value; pyvars2.addInt(subkey, number.intValue()); - } - else if (value instanceof Float || value instanceof Double) { + } else if (value instanceof Float || value instanceof Double) { Number number = (Number) value; pyvars2.addFloat(subkey, number.doubleValue()); - } - else if (value instanceof NumpyArray){ - pyvars2.addNDArray(subkey, (NumpyArray)value); - } - else if (value == null){ + } else if (value instanceof NumpyArray) { + pyvars2.addNDArray(subkey, (NumpyArray) value); + } else if (value == null) { pyvars2.addStr(subkey, "None"); // FixMe - } - else{ + } else { throw new RuntimeException("Unsupported type!" + value); } } return pyvars2; } - public static long[] jsonArrayToLongArray(JSONArray jsonArray){ + public static long[] jsonArrayToLongArray(JSONArray jsonArray) { long[] longs = new long[jsonArray.length()]; - for (int i=0; i toMap(JSONObject jsonobj) { + public static Map toMap(JSONObject jsonobj) { Map map = new HashMap<>(); - String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); - for (String key: keys){ + String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); + for (String key : keys) { Object value = jsonobj.get(key); if (value instanceof JSONArray) { value = toList((JSONArray) value); } else if (value instanceof JSONObject) { - JSONObject jsonobj2 = (JSONObject)value; - if (jsonobj2.has("_is_numpy_array")){ + JSONObject jsonobj2 = (JSONObject) value; + if (jsonobj2.has("_is_numpy_array")) { value = jsonToNumpyArray(jsonobj2); - } - else{ + } else { value = toMap(jsonobj2); } } map.put(key, value); - } return map; + } + return map; } @@ -265,40 +261,35 @@ public class PythonUtils { } - private static NumpyArray jsonToNumpyArray(JSONObject map){ - String dtypeName = (String)map.get("dtype"); + private static NumpyArray jsonToNumpyArray(JSONObject map) { + String dtypeName = (String) map.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")){ + } else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")){ + } else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")){ + } else if (dtypeName.equals("int32")) { dtype = DataType.INT; - } - else if (dtypeName.equals("int64")){ + } else if (dtypeName.equals("int64")) { dtype = DataType.LONG; - } - else{ + } else { throw new RuntimeException("Unsupported array type " + dtypeName + "."); } - List shapeList = (List)map.get("shape"); + List shapeList = map.getJSONArray("shape").toList(); long[] shape = new long[shapeList.size()]; for (int i = 0; i < shape.length; i++) { - shape[i] = (Long)shapeList.get(i); + shape[i] = ((Number) shapeList.get(i)).longValue(); } - List strideList = (List)map.get("shape"); + List strideList = map.getJSONArray("shape").toList(); long[] stride = new long[strideList.size()]; for (int i = 0; i < stride.length; i++) { - stride[i] = (Long)strideList.get(i); + stride[i] = ((Number) strideList.get(i)).longValue(); } - long address = (Long)map.get("address"); - NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + long address = ((Number) map.get("address")).longValue(); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true); return numpyArray; } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java index 4d04f1d87..9d8b5c2a1 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java @@ -17,13 +17,19 @@ package org.datavec.python; import lombok.Data; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; import org.json.JSONObject; import org.json.JSONArray; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.nativeblas.NativeOpsHolder; import java.io.Serializable; +import java.nio.ByteBuffer; import java.util.*; + + /** * Holds python variable names, types and values. * Also handles mapping from java types to python types. @@ -33,41 +39,31 @@ import java.util.*; @lombok.Data public class PythonVariables implements java.io.Serializable { - - public enum Type{ - BOOL, - STR, - INT, - FLOAT, - NDARRAY, - LIST, - FILE, - DICT - - } + private java.util.Map strVariables = new java.util.LinkedHashMap<>(); private java.util.Map intVariables = new java.util.LinkedHashMap<>(); private java.util.Map floatVariables = new java.util.LinkedHashMap<>(); private java.util.Map boolVariables = new java.util.LinkedHashMap<>(); - private java.util.Map ndVars = new java.util.LinkedHashMap<>(); - private java.util.Map listVariables = new java.util.LinkedHashMap<>(); - private java.util.Map fileVariables = new java.util.LinkedHashMap<>(); - private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); - private java.util.Map vars = new java.util.LinkedHashMap<>(); - private java.util.Map maps = new java.util.LinkedHashMap<>(); + private java.util.Map ndVars = new java.util.LinkedHashMap<>(); + private java.util.Map listVariables = new java.util.LinkedHashMap<>(); + private java.util.Map bytesVariables = new java.util.LinkedHashMap<>(); + private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); + private java.util.Map vars = new java.util.LinkedHashMap<>(); + private java.util.Map maps = new java.util.LinkedHashMap<>(); /** * Returns a copy of the variable * schema in this array without the values + * * @return an empty variables clone * with no values */ - public PythonVariables copySchema(){ + public PythonVariables copySchema() { PythonVariables ret = new PythonVariables(); - for (String varName: getVariables()){ - Type type = getType(varName); + for (String varName : getVariables()) { + PythonType type = getType(varName); ret.add(varName, type); } return ret; @@ -77,21 +73,19 @@ public class PythonVariables implements java.io.Serializable { * */ public PythonVariables() { - maps.put(PythonVariables.Type.BOOL, boolVariables); - maps.put(PythonVariables.Type.STR, strVariables); - maps.put(PythonVariables.Type.INT, intVariables); - maps.put(PythonVariables.Type.FLOAT, floatVariables); - maps.put(PythonVariables.Type.NDARRAY, ndVars); - maps.put(PythonVariables.Type.LIST, listVariables); - maps.put(PythonVariables.Type.FILE, fileVariables); - maps.put(PythonVariables.Type.DICT, dictVariables); + maps.put(PythonType.TypeName.BOOL, boolVariables); + maps.put(PythonType.TypeName.STR, strVariables); + maps.put(PythonType.TypeName.INT, intVariables); + maps.put(PythonType.TypeName.FLOAT, floatVariables); + maps.put(PythonType.TypeName.NDARRAY, ndVars); + maps.put(PythonType.TypeName.LIST, listVariables); + maps.put(PythonType.TypeName.DICT, dictVariables); + maps.put(PythonType.TypeName.BYTES, bytesVariables); } - /** - * * @return true if there are no variables. */ public boolean isEmpty() { @@ -100,12 +94,11 @@ public class PythonVariables implements java.io.Serializable { /** - * * @param name Name of the variable * @param type Type of the variable */ - public void add(String name, Type type){ - switch (type){ + public void add(String name, PythonType type) { + switch (type.getName()) { case BOOL: addBool(name); break; @@ -124,21 +117,21 @@ public class PythonVariables implements java.io.Serializable { case LIST: addList(name); break; - case FILE: - addFile(name); - break; case DICT: addDict(name); + break; + case BYTES: + addBytes(name); + break; } } /** - * - * @param name name of the variable - * @param type type of the variable + * @param name name of the variable + * @param type type of the variable * @param value value of the variable (must be instance of expected type) */ - public void add(String name, Type type, Object value) { + public void add(String name, PythonType type, Object value) throws PythonException { add(name, type); setValue(name, value); } @@ -148,21 +141,23 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ public void addDict(String name) { - vars.put(name, PythonVariables.Type.DICT); - dictVariables.put(name,null); + vars.put(name, PythonType.TypeName.DICT); + dictVariables.put(name, null); } /** * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addBool(String name){ - vars.put(name, PythonVariables.Type.BOOL); + public void addBool(String name) { + vars.put(name, PythonType.TypeName.BOOL); boolVariables.put(name, null); } @@ -170,10 +165,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addStr(String name){ - vars.put(name, PythonVariables.Type.STR); + public void addStr(String name) { + vars.put(name, PythonType.TypeName.STR); strVariables.put(name, null); } @@ -181,10 +177,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addInt(String name){ - vars.put(name, PythonVariables.Type.INT); + public void addInt(String name) { + vars.put(name, PythonType.TypeName.INT); intVariables.put(name, null); } @@ -192,10 +189,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addFloat(String name){ - vars.put(name, PythonVariables.Type.FLOAT); + public void addFloat(String name) { + vars.put(name, PythonType.TypeName.FLOAT); floatVariables.put(name, null); } @@ -203,10 +201,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addNDArray(String name){ - vars.put(name, PythonVariables.Type.NDARRAY); + public void addNDArray(String name) { + vars.put(name, PythonType.TypeName.NDARRAY); ndVars.put(name, null); } @@ -214,99 +213,109 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addList(String name){ - vars.put(name, PythonVariables.Type.LIST); + public void addList(String name) { + vars.put(name, PythonType.TypeName.LIST); listVariables.put(name, null); } - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add - */ - public void addFile(String name){ - vars.put(name, PythonVariables.Type.FILE); - fileVariables.put(name, null); - } - /** * Add a boolean variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addBool(String name, boolean value) { - vars.put(name, PythonVariables.Type.BOOL); + vars.put(name, PythonType.TypeName.BOOL); boolVariables.put(name, value); } /** * Add a string variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addStr(String name, String value) { - vars.put(name, PythonVariables.Type.STR); + vars.put(name, PythonType.TypeName.STR); strVariables.put(name, value); } /** * Add an int variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addInt(String name, int value) { - vars.put(name, PythonVariables.Type.INT); - intVariables.put(name, (long)value); + vars.put(name, PythonType.TypeName.INT); + intVariables.put(name, (long) value); } /** * Add a long variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addInt(String name, long value) { - vars.put(name, PythonVariables.Type.INT); + vars.put(name, PythonType.TypeName.INT); intVariables.put(name, value); } /** * Add a double variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addFloat(String name, double value) { - vars.put(name, PythonVariables.Type.FLOAT); + vars.put(name, PythonType.TypeName.FLOAT); floatVariables.put(name, value); } /** * Add a float variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addFloat(String name, float value) { - vars.put(name, PythonVariables.Type.FLOAT); - floatVariables.put(name, (double)value); + vars.put(name, PythonType.TypeName.FLOAT); + floatVariables.put(name, (double) value); } /** * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addNDArray(String name, NumpyArray value) { - vars.put(name, PythonVariables.Type.NDARRAY); + vars.put(name, PythonType.TypeName.NDARRAY); + ndVars.put(name, value.getNd4jArray()); + } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { + vars.put(name, PythonType.TypeName.NDARRAY); ndVars.put(name, value); } @@ -314,117 +323,63 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add - * @param value the value to add - */ - public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { - vars.put(name, PythonVariables.Type.NDARRAY); - ndVars.put(name, new NumpyArray(value)); - } - - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addList(String name, Object[] value) { - vars.put(name, PythonVariables.Type.LIST); - listVariables.put(name, value); + vars.put(name, PythonType.TypeName.LIST); + listVariables.put(name, Arrays.asList(value)); } - + /** * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add - * @param value the value to add - */ - public void addFile(String name, String value) { - vars.put(name, PythonVariables.Type.FILE); - fileVariables.put(name, value); - } - - - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addDict(String name, java.util.Map value) { - vars.put(name, PythonVariables.Type.DICT); + vars.put(name, PythonType.TypeName.DICT); dictVariables.put(name, value); } + + + public void addBytes(String name){ + vars.put(name, PythonType.TypeName.BYTES); + bytesVariables.put(name, null); + } + + public void addBytes(String name, BytePointer value){ + vars.put(name, PythonType.TypeName.BYTES); + bytesVariables.put(name, value); + } + +// public void addBytes(String name, ByteBuffer value){ +// Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress((value.address()); +// BytePointer bp = new BytePointer(ptr); +// addBytes(name, bp); +// } /** - * - * @param name name of the variable + * @param name name of the variable * @param value new value for the variable */ - public void setValue(String name, Object value) { - Type type = vars.get(name); - if (type == PythonVariables.Type.BOOL){ - boolVariables.put(name, (Boolean)value); - } - else if (type == PythonVariables.Type.INT){ - Number number = (Number) value; - intVariables.put(name, number.longValue()); - } - else if (type == PythonVariables.Type.FLOAT){ - Number number = (Number) value; - floatVariables.put(name, number.doubleValue()); - } - else if (type == PythonVariables.Type.NDARRAY){ - if (value instanceof NumpyArray){ - ndVars.put(name, (NumpyArray)value); - } - else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) { - ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value)); - } - else{ - throw new RuntimeException("Unsupported type: " + value.getClass().toString()); - } - } - else if (type == PythonVariables.Type.LIST) { - if (value instanceof java.util.List) { - value = ((java.util.List) value).toArray(); - listVariables.put(name, (Object[]) value); - } - else if(value instanceof org.json.JSONArray) { - org.json.JSONArray jsonArray = (org.json.JSONArray) value; - Object[] copyArr = new Object[jsonArray.length()]; - for(int i = 0; i < copyArr.length; i++) { - copyArr[i] = jsonArray.get(i); - } - listVariables.put(name, copyArr); - - } - else { - listVariables.put(name, (Object[]) value); - } - } - else if(type == PythonVariables.Type.DICT) { - dictVariables.put(name,(java.util.Map) value); - } - else if (type == PythonVariables.Type.FILE){ - fileVariables.put(name, (String)value); - } - else{ - strVariables.put(name, (String)value); - } + public void setValue(String name, Object value) throws PythonException { + PythonType.TypeName type = vars.get(name); + maps.get(type).put(name, PythonType.valueOf(type).convert(value)); } /** * Do a general object lookup. - * The look up will happen relative to the {@link Type} + * The look up will happen relative to the {@link PythonType} * of variable is described in the + * * @param name the name of the variable to get * @return teh value for the variable with the given name */ public Object getValue(String name) { - Type type = vars.get(name); + PythonType.TypeName type = vars.get(name); java.util.Map map = maps.get(type); return map.get(name); } @@ -432,6 +387,7 @@ public class PythonVariables implements java.io.Serializable { /** * Returns a boolean variable with the given name. + * * @param name the variable name to get the value for * @return the retrieved boolean value */ @@ -440,80 +396,78 @@ public class PythonVariables implements java.io.Serializable { } /** - * * @param name the variable name * @return the dictionary value */ - public java.util.Map getDictValue(String name) { + public java.util.Map getDictValue(String name) { return dictVariables.get(name); } /** - /** + * /** * * @param name the variable name * @return the string value */ - public String getStrValue(String name){ + public String getStrValue(String name) { return strVariables.get(name); } /** - * * @param name the variable name * @return the long value */ - public Long getIntValue(String name){ + public Long getIntValue(String name) { return intVariables.get(name); } /** - * * @param name the variable name * @return the float value */ - public Double getFloatValue(String name){ + public Double getFloatValue(String name) { return floatVariables.get(name); } /** - * * @param name the variable name * @return the numpy array value */ - public NumpyArray getNDArrayValue(String name){ + public INDArray getNDArrayValue(String name) { return ndVars.get(name); } /** - * * @param name the variable name * @return the list value as an object array */ - public Object[] getListValue(String name){ + public List getListValue(String name) { return listVariables.get(name); } /** - * * @param name the variable name - * @return the value of the given file name + * @return the bytes value as a BytePointer */ - public String getFileValue(String name){ - return fileVariables.get(name); - } - + public BytePointer getBytesValue(String name){return bytesVariables.get(name);} /** * Returns the type for the given variable name + * * @param name the name of the variable to get the type for * @return the type for the given variable */ - public Type getType(String name){ - return vars.get(name); + public PythonType getType(String name){ + try{ + return PythonType.valueOf(vars.get(name)); // will never fail + }catch (Exception e) + { + throw new RuntimeException(e); + } } /** * Get all the variables present as a string array + * * @return the variable names for this variable sset */ public String[] getVariables() { @@ -524,11 +478,12 @@ public class PythonVariables implements java.io.Serializable { /** * This variables set as its json representation (an array of json objects) + * * @return the json array output */ - public org.json.JSONArray toJSON(){ + public org.json.JSONArray toJSON() { org.json.JSONArray arr = new org.json.JSONArray(); - for (String varName: getVariables()){ + for (String varName : getVariables()) { org.json.JSONObject var = new org.json.JSONObject(); var.put("name", varName); String varType = getType(varName).toString(); @@ -542,13 +497,14 @@ public class PythonVariables implements java.io.Serializable { * Create a schema from a map. * This is an empty PythonVariables * that just contains names and types with no values + * * @param inputTypes the input types to convert * @return the schema from the given map */ - public static PythonVariables schemaFromMap(java.util.Map inputTypes) { + public static PythonVariables schemaFromMap(java.util.Map inputTypes) throws Exception{ PythonVariables ret = new PythonVariables(); - for(java.util.Map.Entry entry : inputTypes.entrySet()) { - ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue())); + for (java.util.Map.Entry entry : inputTypes.entrySet()) { + ret.add(entry.getKey(), PythonType.valueOf(entry.getValue())); } return ret; @@ -557,39 +513,17 @@ public class PythonVariables implements java.io.Serializable { /** * Get the python variable state relative to the * input json array + * * @param jsonArray the input json array * @return the python variables based on the input json array */ - public static PythonVariables fromJSON(org.json.JSONArray jsonArray){ + public static PythonVariables fromJSON(org.json.JSONArray jsonArray) { PythonVariables pyvars = new PythonVariables(); for (int i = 0; i < jsonArray.length(); i++) { org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i); - String varName = (String)input.get("name"); - String varType = (String)input.get("type"); - if (varType.equals("BOOL")) { - pyvars.addBool(varName); - } - else if (varType.equals("INT")) { - pyvars.addInt(varName); - } - else if (varType.equals("FlOAT")){ - pyvars.addFloat(varName); - } - else if (varType.equals("STR")) { - pyvars.addStr(varName); - } - else if (varType.equals("LIST")) { - pyvars.addList(varName); - } - else if (varType.equals("FILE")){ - pyvars.addFile(varName); - } - else if (varType.equals("NDARRAY")) { - pyvars.addNDArray(varName); - } - else if(varType.equals("DICT")) { - pyvars.addDict(varName); - } + String varName = (String) input.get("name"); + String varType = (String) input.get("type"); + pyvars.maps.get(PythonType.TypeName.valueOf(varType)).put(varName, null); } return pyvars; diff --git a/datavec/datavec-python/src/main/resources/pythonexec/__init__.py b/datavec/datavec-python/src/main/resources/pythonexec/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py b/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py deleted file mode 100644 index 239ad694f..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py +++ /dev/null @@ -1,5 +0,0 @@ -#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script -import sys -this = sys.modules[__name__] -for n in dir(): - if n[0]!='_': delattr(this, n) \ No newline at end of file diff --git a/datavec/datavec-python/src/main/resources/pythonexec/input_code.py b/datavec/datavec-python/src/main/resources/pythonexec/input_code.py deleted file mode 100644 index 92ea40ac5..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/input_code.py +++ /dev/null @@ -1 +0,0 @@ -loc = {} diff --git a/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py b/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py deleted file mode 100644 index 7f70a0e93..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py +++ /dev/null @@ -1,20 +0,0 @@ - -def __is_numpy_array(x): - return str(type(x))== "" - -def maybe_serialize_ndarray_metadata(x): - return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x - - -def serialize_ndarray_metadata(x): - return {"address": x.__array_interface__['data'][0], - "shape": x.shape, - "strides": x.strides, - "dtype": str(x.dtype), - "_is_numpy_array": True} if __is_numpy_array(x) else x - - -def is_json_ready(key, value): - return key is not 'f2' and not inspect.ismodule(value) \ - and not hasattr(value, '__call__') - diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py deleted file mode 100644 index d2ed3d5e5..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py +++ /dev/null @@ -1,202 +0,0 @@ -#patch - -"""Implementation of __array_function__ overrides from NEP-18.""" -import collections -import functools -import os - -from numpy.core._multiarray_umath import ( - add_docstring, implement_array_function, _get_implementing_args) -from numpy.compat._inspect import getargspec - - -ENABLE_ARRAY_FUNCTION = bool( - int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) - - -ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat - - -_add_docstring = add_docstring - - -def add_docstring(*args): - try: - _add_docstring(*args) - except: - pass - - -add_docstring( - implement_array_function, - """ - Implement a function with checks for __array_function__ overrides. - - All arguments are required, and can only be passed by position. - - Arguments - --------- - implementation : function - Function that implements the operation on NumPy array without - overrides when called like ``implementation(*args, **kwargs)``. - public_api : function - Function exposed by NumPy's public API originally called like - ``public_api(*args, **kwargs)`` on which arguments are now being - checked. - relevant_args : iterable - Iterable of arguments to check for __array_function__ methods. - args : tuple - Arbitrary positional arguments originally passed into ``public_api``. - kwargs : dict - Arbitrary keyword arguments originally passed into ``public_api``. - - Returns - ------- - Result from calling ``implementation()`` or an ``__array_function__`` - method, as appropriate. - - Raises - ------ - TypeError : if no implementation is found. - """) - - -# exposed for testing purposes; used internally by implement_array_function -add_docstring( - _get_implementing_args, - """ - Collect arguments on which to call __array_function__. - - Parameters - ---------- - relevant_args : iterable of array-like - Iterable of possibly array-like arguments to check for - __array_function__ methods. - - Returns - ------- - Sequence of arguments with __array_function__ methods, in the order in - which they should be called. - """) - - -ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') - - -def verify_matching_signatures(implementation, dispatcher): - """Verify that a dispatcher function has the right signature.""" - implementation_spec = ArgSpec(*getargspec(implementation)) - dispatcher_spec = ArgSpec(*getargspec(dispatcher)) - - if (implementation_spec.args != dispatcher_spec.args or - implementation_spec.varargs != dispatcher_spec.varargs or - implementation_spec.keywords != dispatcher_spec.keywords or - (bool(implementation_spec.defaults) != - bool(dispatcher_spec.defaults)) or - (implementation_spec.defaults is not None and - len(implementation_spec.defaults) != - len(dispatcher_spec.defaults))): - raise RuntimeError('implementation and dispatcher for %s have ' - 'different function signatures' % implementation) - - if implementation_spec.defaults is not None: - if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): - raise RuntimeError('dispatcher functions can only use None for ' - 'default argument values') - - -def set_module(module): - """Decorator for overriding __module__ on a function or class. - - Example usage:: - - @set_module('numpy') - def example(): - pass - - assert example.__module__ == 'numpy' - """ - def decorator(func): - if module is not None: - func.__module__ = module - return func - return decorator - - -def array_function_dispatch(dispatcher, module=None, verify=True, - docs_from_dispatcher=False): - """Decorator for adding dispatch with the __array_function__ protocol. - - See NEP-18 for example usage. - - Parameters - ---------- - dispatcher : callable - Function that when called like ``dispatcher(*args, **kwargs)`` with - arguments from the NumPy function call returns an iterable of - array-like arguments to check for ``__array_function__``. - module : str, optional - __module__ attribute to set on new function, e.g., ``module='numpy'``. - By default, module is copied from the decorated function. - verify : bool, optional - If True, verify the that the signature of the dispatcher and decorated - function signatures match exactly: all required and optional arguments - should appear in order with the same names, but the default values for - all optional arguments should be ``None``. Only disable verification - if the dispatcher's signature needs to deviate for some particular - reason, e.g., because the function has a signature like - ``func(*args, **kwargs)``. - docs_from_dispatcher : bool, optional - If True, copy docs from the dispatcher function onto the dispatched - function, rather than from the implementation. This is useful for - functions defined in C, which otherwise don't have docstrings. - - Returns - ------- - Function suitable for decorating the implementation of a NumPy function. - """ - - if not ENABLE_ARRAY_FUNCTION: - # __array_function__ requires an explicit opt-in for now - def decorator(implementation): - if module is not None: - implementation.__module__ = module - if docs_from_dispatcher: - add_docstring(implementation, dispatcher.__doc__) - return implementation - return decorator - - def decorator(implementation): - if verify: - verify_matching_signatures(implementation, dispatcher) - - if docs_from_dispatcher: - add_docstring(implementation, dispatcher.__doc__) - - @functools.wraps(implementation) - def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return implement_array_function( - implementation, public_api, relevant_args, args, kwargs) - - if module is not None: - public_api.__module__ = module - - # TODO: remove this when we drop Python 2 support (functools.wraps) - # adds __wrapped__ automatically in later versions) - public_api.__wrapped__ = implementation - - return public_api - - return decorator - - -def array_function_from_dispatcher( - implementation, module=None, verify=True, docs_from_dispatcher=True): - """Like array_function_dispatcher, but with function arguments flipped.""" - - def decorator(dispatcher): - return array_function_dispatch( - dispatcher, module, verify=verify, - docs_from_dispatcher=docs_from_dispatcher)(implementation) - return decorator diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py deleted file mode 100644 index 890852bbc..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py +++ /dev/null @@ -1,172 +0,0 @@ -#patch 1 - -""" -======================== -Random Number Generation -======================== - -==================== ========================================================= -Utility functions -============================================================================== -random_sample Uniformly distributed floats over ``[0, 1)``. -random Alias for `random_sample`. -bytes Uniformly distributed random bytes. -random_integers Uniformly distributed integers in a given range. -permutation Randomly permute a sequence / generate a random sequence. -shuffle Randomly permute a sequence in place. -seed Seed the random number generator. -choice Random sample from 1-D array. - -==================== ========================================================= - -==================== ========================================================= -Compatibility functions -============================================================================== -rand Uniformly distributed values. -randn Normally distributed values. -ranf Uniformly distributed floating point numbers. -randint Uniformly distributed integers in a given range. -==================== ========================================================= - -==================== ========================================================= -Univariate distributions -============================================================================== -beta Beta distribution over ``[0, 1]``. -binomial Binomial distribution. -chisquare :math:`\\chi^2` distribution. -exponential Exponential distribution. -f F (Fisher-Snedecor) distribution. -gamma Gamma distribution. -geometric Geometric distribution. -gumbel Gumbel distribution. -hypergeometric Hypergeometric distribution. -laplace Laplace distribution. -logistic Logistic distribution. -lognormal Log-normal distribution. -logseries Logarithmic series distribution. -negative_binomial Negative binomial distribution. -noncentral_chisquare Non-central chi-square distribution. -noncentral_f Non-central F distribution. -normal Normal / Gaussian distribution. -pareto Pareto distribution. -poisson Poisson distribution. -power Power distribution. -rayleigh Rayleigh distribution. -triangular Triangular distribution. -uniform Uniform distribution. -vonmises Von Mises circular distribution. -wald Wald (inverse Gaussian) distribution. -weibull Weibull distribution. -zipf Zipf's distribution over ranked data. -==================== ========================================================= - -==================== ========================================================= -Multivariate distributions -============================================================================== -dirichlet Multivariate generalization of Beta distribution. -multinomial Multivariate generalization of the binomial distribution. -multivariate_normal Multivariate generalization of the normal distribution. -==================== ========================================================= - -==================== ========================================================= -Standard distributions -============================================================================== -standard_cauchy Standard Cauchy-Lorentz distribution. -standard_exponential Standard exponential distribution. -standard_gamma Standard Gamma distribution. -standard_normal Standard normal distribution. -standard_t Standard Student's t-distribution. -==================== ========================================================= - -==================== ========================================================= -Internal functions -============================================================================== -get_state Get tuple representing internal state of generator. -set_state Set state of generator. -==================== ========================================================= - -""" -from __future__ import division, absolute_import, print_function - -import warnings - -__all__ = [ - 'beta', - 'binomial', - 'bytes', - 'chisquare', - 'choice', - 'dirichlet', - 'exponential', - 'f', - 'gamma', - 'geometric', - 'get_state', - 'gumbel', - 'hypergeometric', - 'laplace', - 'logistic', - 'lognormal', - 'logseries', - 'multinomial', - 'multivariate_normal', - 'negative_binomial', - 'noncentral_chisquare', - 'noncentral_f', - 'normal', - 'pareto', - 'permutation', - 'poisson', - 'power', - 'rand', - 'randint', - 'randn', - 'random_integers', - 'random_sample', - 'rayleigh', - 'seed', - 'set_state', - 'shuffle', - 'standard_cauchy', - 'standard_exponential', - 'standard_gamma', - 'standard_normal', - 'standard_t', - 'triangular', - 'uniform', - 'vonmises', - 'wald', - 'weibull', - 'zipf' -] - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="numpy.ndarray size changed") - try: - from .mtrand import * - # Some aliases: - ranf = random = sample = random_sample - __all__.extend(['ranf', 'random', 'sample']) - except: - warnings.warn("numpy.random is not available when using multiple interpreters!") - - - -def __RandomState_ctor(): - """Return a RandomState instance. - - This function exists solely to assist (un)pickling. - - Note that the state of the RandomState returned here is irrelevant, as this function's - entire purpose is to return a newly allocated RandomState whose state pickle can set. - Consequently the RandomState returned by this function is a freshly allocated copy - with a seed=0. - - See https://github.com/numpy/numpy/issues/4763 for a detailed discussion - - """ - return RandomState(seed=0) - -from numpy._pytesttester import PytestTester -test = PytestTester(__name__) -del PytestTester diff --git a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py index dbdceff0e..1509610c7 100644 --- a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py +++ b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py @@ -3,13 +3,13 @@ import traceback import json import inspect - +__python_exception__ = "" try: - pass sys.stdout.flush() sys.stderr.flush() except Exception as ex: + __python_exception__ = ex try: exc_info = sys.exc_info() finally: diff --git a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py deleted file mode 100644 index ac6f5b1c1..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py +++ /dev/null @@ -1,50 +0,0 @@ -def __is_numpy_array(x): - return str(type(x))== "" - -def __maybe_serialize_ndarray_metadata(x): - return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x - - -def __serialize_ndarray_metadata(x): - return {"address": x.__array_interface__['data'][0], - "shape": x.shape, - "strides": x.strides, - "dtype": str(x.dtype), - "_is_numpy_array": True} if __is_numpy_array(x) else x - - -def __serialize_list(x): - import json - return json.dumps(__recursive_serialize_list(x)) - - -def __serialize_dict(x): - import json - return json.dumps(__recursive_serialize_dict(x)) - -def __recursive_serialize_list(x): - out = [] - for i in x: - if __is_numpy_array(i): - out.append(__serialize_ndarray_metadata(i)) - elif isinstance(i, (list, tuple)): - out.append(__recursive_serialize_list(i)) - elif isinstance(i, dict): - out.append(__recursive_serialize_dict(i)) - else: - out.append(i) - return out - -def __recursive_serialize_dict(x): - out = {} - for k in x: - v = x[k] - if __is_numpy_array(v): - out[k] = __serialize_ndarray_metadata(v) - elif isinstance(v, (list, tuple)): - out[k] = __recursive_serialize_list(v) - elif isinstance(v, dict): - out[k] = __recursive_serialize_dict(v) - else: - out[k] = v - return out \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java new file mode 100644 index 000000000..4abad139f --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java @@ -0,0 +1,87 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; + +@NotThreadSafe +public class TestPythonContextManager { + + @Test + public void testInt() throws Exception{ + Python.setContext("context1"); + Python.exec("a = 1"); + Python.setContext("context2"); + Python.exec("a = 2"); + Python.setContext("context3"); + Python.exec("a = 3"); + + + Python.setContext("context1"); + Assert.assertEquals(1, PythonExecutioner.getVariable("a").toInt()); + + Python.setContext("context2"); + Assert.assertEquals(2, PythonExecutioner.getVariable("a").toInt()); + + Python.setContext("context3"); + Assert.assertEquals(3, PythonExecutioner.getVariable("a").toInt()); + + PythonContextManager.deleteNonMainContexts(); + } + + @Test + public void testNDArray() throws Exception{ + Python.setContext("context1"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 1"); + + Python.setContext("context2"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 2"); + + Python.setContext("context3"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 3"); + + Python.setContext("context1"); + Python.exec("a += 1"); + + Python.setContext("context2"); + Python.exec("a += 2"); + + Python.setContext("context3"); + Python.exec("a += 3"); + + INDArray arr = Nd4j.create(DataType.DOUBLE, 3, 2); + Python.setContext("context1"); + Assert.assertEquals(arr.add(2), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + + Python.setContext("context2"); + Assert.assertEquals(arr.add(4), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + + Python.setContext("context3"); + Assert.assertEquals(arr.add(6), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java new file mode 100644 index 000000000..bc06269f1 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java @@ -0,0 +1,64 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import lombok.var; +import org.json.JSONArray; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonDict { + + @Test + public void testPythonDictFromMap() throws Exception{ + Map map = new HashMap<>(); + map.put("a", 1); + map.put("b", "a"); + map.put("1", Arrays.asList(1, 2, 3, "4", Arrays.asList("x", 2.3))); + Map innerMap = new HashMap<>(); + innerMap.put("k", 32); + map.put("inner", innerMap); + map.put("ndarray", Nd4j.linspace(1, 4, 4)); + innerMap.put("ndarray", Nd4j.linspace(5, 8, 4)); + PythonObject dict = new PythonObject(map); + assertEquals(map.size(), Python.len(dict).toInt()); + assertEquals("{'a': 1, '1': [1, 2, 3, '4', ['" + + "x', 2.3]], 'b': 'a', 'inner': {'k': 32," + + " 'ndarray': array([5., 6., 7., 8.], dty" + + "pe=float32)}, 'ndarray': array([1., 2., " + + "3., 4.], dtype=float32)}", + dict.toString()); + Map map2 = dict.toMap(); + PythonObject dict2 = new PythonObject(map2); + assertEquals(dict.toString(), dict2.toString()); + + + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java deleted file mode 100644 index 435babf7c..000000000 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.python; - - -import org.junit.Assert; -import org.junit.Test; - -@javax.annotation.concurrent.NotThreadSafe -public class TestPythonExecutionSandbox { - - @Test - public void testInt(){ - PythonExecutioner.setInterpreter("interp1"); - PythonExecutioner.exec("a = 1"); - PythonExecutioner.setInterpreter("interp2"); - PythonExecutioner.exec("a = 2"); - PythonExecutioner.setInterpreter("interp3"); - PythonExecutioner.exec("a = 3"); - - - PythonExecutioner.setInterpreter("interp1"); - Assert.assertEquals(1, PythonExecutioner.evalInteger("a")); - - PythonExecutioner.setInterpreter("interp2"); - Assert.assertEquals(2, PythonExecutioner.evalInteger("a")); - - PythonExecutioner.setInterpreter("interp3"); - Assert.assertEquals(3, PythonExecutioner.evalInteger("a")); - } - - @Test - public void testNDArray(){ - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("import numpy as np"); - PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - //PythonExecutioner.exec("import numpy as np"); - PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("a += 2"); - - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("a += 3"); - - PythonExecutioner.setInterpreter("main"); - //PythonExecutioner.exec("import numpy as np"); - // PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5); - } - - @Test - public void testNumpyRandom(){ - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))"); - } -} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index c8e67febb..bb436e808 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -15,13 +15,17 @@ ******************************************************************************/ package org.datavec.python; + +import org.bytedeco.javacpp.BytePointer; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; @javax.annotation.concurrent.NotThreadSafe @@ -29,12 +33,12 @@ public class TestPythonExecutioner { @org.junit.Test - public void testPythonSysVersion() { - PythonExecutioner.exec("import sys; print(sys.version)"); + public void testPythonSysVersion() throws PythonException { + Python.exec("import sys; print(sys.version)"); } @Test - public void testStr() throws Exception{ + public void testStr() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -46,7 +50,7 @@ public class TestPythonExecutioner { String code = "z = x + ' ' + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); String z = pyOutputs.getStrValue("z"); @@ -56,7 +60,7 @@ public class TestPythonExecutioner { } @Test - public void testInt()throws Exception{ + public void testInt() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -68,7 +72,7 @@ public class TestPythonExecutioner { pyOutputs.addInt("z"); - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); long z = pyOutputs.getIntValue("z"); @@ -77,7 +81,7 @@ public class TestPythonExecutioner { } @Test - public void testList() throws Exception{ + public void testList() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -92,30 +96,28 @@ public class TestPythonExecutioner { pyOutputs.addList("z"); - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); - Object[] z = pyOutputs.getListValue("z"); + Object[] z = pyOutputs.getListValue("z").toArray(); Assert.assertEquals(z.length, x.length + y.length); for (int i = 0; i < x.length; i++) { - if(x[i] instanceof Number) { + if (x[i] instanceof Number) { Number xNum = (Number) x[i]; Number zNum = (Number) z[i]; Assert.assertEquals(xNum.intValue(), zNum.intValue()); - } - else { + } else { Assert.assertEquals(x[i], z[i]); } } - for (int i = 0; i < y.length; i++){ - if(y[i] instanceof Number) { + for (int i = 0; i < y.length; i++) { + if (y[i] instanceof Number) { Number yNum = (Number) y[i]; Number zNum = (Number) z[x.length + i]; Assert.assertEquals(yNum.intValue(), zNum.intValue()); - } - else { + } else { Assert.assertEquals(y[i], z[x.length + i]); } @@ -125,7 +127,7 @@ public class TestPythonExecutioner { } @Test - public void testNDArrayFloat()throws Exception{ + public void testNDArrayFloat() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -135,8 +137,8 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); @@ -144,12 +146,13 @@ public class TestPythonExecutioner { } @Test - public void testTensorflowCustomAnaconda() { - PythonExecutioner.exec("import tensorflow as tf"); + @Ignore + public void testTensorflowCustomAnaconda() throws PythonException { + Python.exec("import tensorflow as tf"); } @Test - public void testNDArrayDouble()throws Exception { + public void testNDArrayDouble() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -159,14 +162,14 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayShort()throws Exception{ + public void testNDArrayShort() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -176,15 +179,15 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayInt()throws Exception{ + public void testNDArrayInt() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -194,15 +197,15 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayLong()throws Exception{ + public void testNDArrayLong() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -212,12 +215,91 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); + } + @Test + public void testByteBufferInput() throws Exception{ + //ByteBuffer buff = ByteBuffer.allocateDirect(3); + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs= new PythonVariables(); + pyOutputs.addStr("out"); + + String code = "out = buff.decode()"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("abc", pyOutputs.getStrValue("out")); + + } + + @Test + public void testByteBufferOutputNoCopy() throws Exception{ + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs = new PythonVariables(); + pyOutputs.addBytes("buff"); // same name as input, because inplace update + + String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString()); + } + + @Test + public void testByteBufferOutputWithCopy() throws Exception{ + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs = new PythonVariables(); + pyOutputs.addBytes("out"); + + String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97\nout=bytes(buff)"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString()); + } + @Test + public void testBadCode() throws Exception{ + Python.setContext("badcode"); + PythonVariables pyInputs = new PythonVariables(); + PythonVariables pyOutputs = new PythonVariables(); + + pyInputs.addNDArray("x", Nd4j.zeros(DataType.LONG, 2, 3)); + pyInputs.addNDArray("y", Nd4j.ones(DataType.LONG, 2, 3)); + pyOutputs.addNDArray("z"); + + String code = "z = x + a"; + + try{ + Python.exec(code, pyInputs, pyOutputs); + fail("No exception thrown"); + } catch (PythonException pe ){ + Assert.assertEquals("NameError: name 'a' is not defined", pe.getMessage()); + } + + Python.setMainContext(); } } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java new file mode 100644 index 000000000..bb79d2837 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java @@ -0,0 +1,326 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; + + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonJob { + + @Test + public void testPythonJobBasic() throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + PythonJob job = new PythonJob("job1", code, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job.exec(inputs, outputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job.exec(inputs, outputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + } + + @Test + public void testPythonJobReturnAllVariables()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + PythonJob job = new PythonJob("job1", code, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + + PythonVariables outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + } + + @Test + public void testMultiplePythonJobsParallel()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code1 = "c = a + b"; + PythonJob job1 = new PythonJob("job1", code1, false); + + String code2 = "c = a - b"; + PythonJob job2 = new PythonJob("job2", code2, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job1.exec(inputs, outputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(-1L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job1.exec(inputs, outputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + job2.exec(inputs, outputs); + + assertEquals(-1L, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).sub(1), outputs.getNDArrayValue("c")); + } + @Test + public void testPythonJobSetupRun()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job = new PythonJob("job1", code, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job.exec(inputs, outputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job.exec(inputs, outputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + } + @Test + public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job = new PythonJob("job1", code, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + + PythonVariables outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + } + + @Test + public void testMultiplePythonJobsSetupRunParallel()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code1 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job1 = new PythonJob("job1", code1, true); + + String code2 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b - five\n"+ + " return {'c':c}\n\n"; + PythonJob job2 = new PythonJob("job2", code2, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job1.exec(inputs, outputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(0L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job1.exec(inputs, outputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + job2.exec(inputs, outputs); + + assertEquals(2L, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(4), outputs.getNDArrayValue("c")); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java new file mode 100644 index 000000000..7362a5e49 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java @@ -0,0 +1,108 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import lombok.var; +import org.json.JSONArray; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonList { + + @Test + public void testPythonListFromIntArray() { + PythonObject pyList = new PythonObject(new Integer[]{1, 2, 3, 4, 5}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + } + + } + + @Test + public void testPythonListFromLongArray() { + PythonObject pyList = new PythonObject(new Long[]{1L, 2L, 3L, 4L, 5L}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + } + + } + + @Test + public void testPythonListFromDoubleArray() { + PythonObject pyList = new PythonObject(new Double[]{1., 2., 3., 4., 5.}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + assertEquals((double) i + 1, pyList.get(i).toDouble(), 1e-5); + } + + } + + @Test + public void testPythonListFromStringArray() { + PythonObject pyList = new PythonObject(new String[]{"abcd", "efg"}); + pyList.attr("append").call("hijk"); + pyList.attr("append").call("lmnop"); + assertEquals("abcdefghijklmnop", new PythonObject("").attr("join").call(pyList).toString()); + } + + @Test + public void testPythonListFromMixedArray()throws Exception { + Map map = new HashMap<>(); + map.put(1, "a"); + map.put("a", Arrays.asList("a", "b", "c")); + map.put("arr", Nd4j.linspace(1, 4, 4)); + Object[] objs = new Object[]{ + 1, 2, "a", 3f, 4L, 5.0, Arrays.asList(10, + 20, "b", 30f, 40L, 50.0, map + + ), map + }; + PythonObject pyList = new PythonObject(objs); + System.out.println(pyList.toString()); + String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" + + ", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," + + " 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" + + "'a', 'b', 'c']}], {'arr': array([1., 2., 3.," + + " 4.], dtype=float32), 1: 'a', 'a': ['a', 'b', 'c']}]"; + assertEquals(expectedStr, pyList.toString()); + List objs2 = pyList.toList(); + PythonObject pyList2 = new PythonObject(objs2); + assertEquals(pyList.toString(), pyList2.toString()); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java deleted file mode 100644 index 42a22c07f..000000000 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.datavec.python; - -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -@javax.annotation.concurrent.NotThreadSafe -public class TestPythonSetupAndRun { - @Test - public void testPythonWithSetupAndRun() throws Exception{ - String code = "def setup():" + - "global counter;counter=0\n" + - "def run(step):" + - "global counter;" + - "counter+=step;" + - "return {\"counter\":counter}"; - PythonVariables pyInputs = new PythonVariables(); - pyInputs.addInt("step", 2); - PythonVariables pyOutputs = new PythonVariables(); - pyOutputs.addInt("counter"); - PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs); - assertEquals((long)pyOutputs.getIntValue("counter"), 2L); - pyInputs.addInt("step", 3); - PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs); - assertEquals((long)pyOutputs.getIntValue("counter"), 5L); - } -} \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java index 0f14cb756..22f8ba230 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java @@ -22,11 +22,14 @@ package org.datavec.python; +import org.bytedeco.javacpp.BytePointer; import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.Collections; +import java.util.List; import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNull; @@ -36,59 +39,50 @@ import static org.junit.Assert.assertTrue; public class TestPythonVariables { - - @Test - public void testImportNumpy(){ - Nd4j.scalar(1.0); - System.out.println(System.getProperty("org.bytedeco.openblas.load")); - PythonExecutioner.exec("import numpy as np"); - } - - - @Test - public void testDataAssociations() { + public void testDataAssociations() throws PythonException{ PythonVariables pythonVariables = new PythonVariables(); - PythonVariables.Type[] types = { - PythonVariables.Type.INT, - PythonVariables.Type.FLOAT, - PythonVariables.Type.STR, - PythonVariables.Type.BOOL, - PythonVariables.Type.DICT, - PythonVariables.Type.LIST, - PythonVariables.Type.LIST, - PythonVariables.Type.FILE, - PythonVariables.Type.NDARRAY + PythonType[] types = { + PythonType.INT, + PythonType.FLOAT, + PythonType.STR, + PythonType.BOOL, + PythonType.DICT, + PythonType.LIST, + PythonType.LIST, + PythonType.NDARRAY, + PythonType.BYTES }; - NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0)); + INDArray arr = Nd4j.scalar(1.0); + BytePointer bp = new BytePointer(arr.data().pointer()); Object[] values = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - new Object[]{1}, Arrays.asList(1),"type", npArr + new Object[]{1}, Arrays.asList(1), arr, bp }; Object[] expectedValues = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - new Object[]{1}, new Object[]{1},"type", npArr + Arrays.asList(1), Arrays.asList(1), arr, bp }; for(int i = 0; i < types.length; i++) { - testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]); + testInsertGet(pythonVariables,types[i].getName().name() + i,values[i],types[i],expectedValues[i]); } assertEquals(types.length,pythonVariables.getVariables().length); } - private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) { + private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonType type,Object expectedValue) throws PythonException{ pythonVariables.add(key, type); assertNull(pythonVariables.getValue(key)); pythonVariables.setValue(key,value); assertNotNull(pythonVariables.getValue(key)); Object actualValue = pythonVariables.getValue(key); if (expectedValue instanceof Object[]){ - assertTrue(actualValue instanceof Object[]); - Object[] actualArr = (Object[])actualValue; + assertTrue(actualValue instanceof List); + Object[] actualArr = ((List)actualValue).toArray(); Object[] expectedArr = (Object[])expectedValue; assertArrayEquals(expectedArr, actualArr); } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java index cc12174f8..71d37ca91 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java @@ -44,7 +44,7 @@ public class TestSerde { String yaml = y.serialize(t); String json = j.serialize(t); - Transform t2 = y.deserializeTransform(json); + Transform t2 = y.deserializeTransform(yaml); Transform t3 = j.deserializeTransform(json); assertEquals(t, t2); assertEquals(t, t3);