Python executioner 2.0 (#134)

* wrapper

* builtins

* context mgr

* direct place

* ret all var

* call fix

* fix ndarray serde

* jobs

* try-with gil management

* cleanup

* exec tests passing

* list tests

* transforms test passing

* all pass

* headers

* dict fixes+test

* python path

* bool isinstance

* job tests

* nits

* transform fix+test

* transform tests

* leak fixes

* more mem leak fixes

* more fixes

* nits for adam

* PythonJob lombok builder

* checked exceptions

* more nits

* small leak fix

* more nits

* pythonexceptions

* fix jvm crash when bad python code

* Exception->PythonException

* Add support for boolean types in arrow records and ability to cast from float, double to int for TypeConversion (#178)

* nits for alex

* update tests

* fix test

* all pass

* refacc

* rem old code

* dtypes

* bytes working+exception pass through+cleanup (#209)

* more bytes tests

* header

* rem dummy test

* rem bad import

* alex nits + refacc

* Small error fixes (wrong type in msg) + minor formatting

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* use actual python type names (dictionary->dict, boolean->bool)

Co-authored-by: Shams Ul Azeem <shamsazeem20@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
master
Fariz Rahman 2020-02-04 13:23:59 +04:00 committed by GitHub
parent 7ea07de76b
commit f6b3032def
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 2907 additions and 2076 deletions

View File

@ -45,7 +45,7 @@ public class TypeConversion {
} }
public int convertInt(String o) { public int convertInt(String o) {
return Integer.parseInt(o); return (int) Double.parseDouble(o);
} }
public double convertDouble(Writable writable) { public double convertDouble(Writable writable) {

View File

@ -725,7 +725,6 @@ public class ArrowConverter {
case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i)); default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i));
} }
} }
@ -802,8 +801,13 @@ public class ArrowConverter {
//for proper offsets //for proper offsets
ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get()); ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity()); nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
case Boolean:
BitVector bitVector = (BitVector) fieldVector;
if(value instanceof Boolean)
bitVector.set(row, (boolean) value ? 1 : 0);
else
bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0);
break; break;
} }
}catch(Exception e) { }catch(Exception e) {
log.warn("Unable to set value at row " + row); log.warn("Unable to set value at row " + row);

View File

@ -315,7 +315,7 @@ public class TestPythonTransformProcess {
} }
@Test @Test
public void testNumpyTransform() throws Exception { public void testNumpyTransform() {
PythonTransform pythonTransform = PythonTransform.builder() PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'") .code("a += 2; b = 'hello world'")
.returnAllInputs(true) .returnAllInputs(true)
@ -334,7 +334,42 @@ public class TestPythonTransformProcess {
assertFalse(execute.isEmpty()); assertFalse(execute.isEmpty());
assertNotNull(execute.get(0)); assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0)); assertNotNull(execute.get(0).get(0));
assertEquals("hello world",execute.get(0).get(0).toString()); assertNotNull(execute.get(0).get(1));
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testWithSetupRun() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n")
.returnAllInputs(true)
.setupAndRun(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.addColumnNDArray("b", new long[]{1, 1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
} }
} }

View File

@ -29,6 +29,7 @@ import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
/** /**
@ -46,55 +47,45 @@ public class NumpyArray {
private long[] strides; private long[] strides;
private DataType dtype; private DataType dtype;
private INDArray nd4jArray; private INDArray nd4jArray;
static { static {
//initialize //initialize
Nd4j.scalar(1.0); Nd4j.scalar(1.0);
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
} }
@Builder @Builder
public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) { public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) {
this.address = address; this.address = address;
this.shape = shape; this.shape = shape;
this.strides = strides; this.strides = strides;
this.dtype = dtype; this.dtype = dtype;
setND4JArray(); setND4JArray();
if (copy){ if (copy) {
nd4jArray = nd4jArray.dup(); nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address(); this.address = nd4jArray.data().address();
} }
} }
public NumpyArray copy(){
public NumpyArray copy() {
return new NumpyArray(nd4jArray.dup()); return new NumpyArray(nd4jArray.dup());
} }
public NumpyArray(long address, long[] shape, long strides[]){ public NumpyArray(long address, long[] shape, long strides[]) {
this(address, shape, strides, false,DataType.FLOAT); this(address, shape, strides, FLOAT, false);
} }
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ public NumpyArray(long address, long[] shape, long strides[], DataType dtype) {
this(address, shape, strides, dtype, false); this(address, shape, strides, dtype, false);
} }
public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy){
this.address = address;
this.shape = shape;
this.strides = strides;
this.dtype = dtype;
setND4JArray();
if (copy){
nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address();
}
}
private void setND4JArray() { private void setND4JArray() {
long size = 1; long size = 1;
for(long d: shape) { for (long d : shape) {
size *= d; size *= d;
} }
Pointer ptr = nativeOps.pointerForAddress(address); Pointer ptr = nativeOps.pointerForAddress(address);
@ -107,11 +98,11 @@ public class NumpyArray {
nd4jStrides[i] = strides[i] / elemSize; nd4jStrides[i] = strides[i] / elemSize;
} }
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
} }
public NumpyArray(INDArray nd4jArray){ public NumpyArray(INDArray nd4jArray) {
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
DataBuffer buff = nd4jArray.data(); DataBuffer buff = nd4jArray.data();
address = buff.pointer().address(); address = buff.pointer().address();
@ -119,7 +110,7 @@ public class NumpyArray {
long[] nd4jStrides = nd4jArray.stride(); long[] nd4jStrides = nd4jArray.stride();
strides = new long[nd4jStrides.length]; strides = new long[nd4jStrides.length];
int elemSize = buff.getElementSize(); int elemSize = buff.getElementSize();
for(int i=0; i<strides.length; i++){ for (int i = 0; i < strides.length; i++) {
strides[i] = nd4jStrides[i] * elemSize; strides[i] = nd4jStrides[i] * elemSize;
} }
dtype = nd4jArray.dataType(); dtype = nd4jArray.dataType();

View File

@ -0,0 +1,265 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.bytedeco.cpython.PyObject;
import static org.bytedeco.cpython.global.python.*;
/**
* Swift like python wrapper for Java
*
* @author Fariz Rahman
*/
public class Python {
/**
* Imports a python module, similar to python import statement.
* @param moduleName name of the module to be imported
* @return reference to the module object
* @throws PythonException
*/
public static PythonObject importModule(String moduleName) throws PythonException{
PythonObject module = new PythonObject(PyImport_ImportModule(moduleName));
if (module.isNone()) {
throw new PythonException("Error importing module: " + moduleName);
}
return module;
}
public static PythonObject attr(String attrName) {
return builtins().attr(attrName);
}
public static PythonObject len(PythonObject pythonObject) {
return attr("len").call(pythonObject);
}
public static PythonObject str(PythonObject pythonObject) {
return attr("str").call(pythonObject);
}
public static PythonObject str() {
return attr("str").call();
}
public static PythonObject strType() {
return attr("str");
}
public static PythonObject float_(PythonObject pythonObject) {
return attr("float").call(pythonObject);
}
public static PythonObject float_() {
return attr("float").call();
}
public static PythonObject floatType() {
return attr("float");
}
public static PythonObject bool(PythonObject pythonObject) {
return attr("bool").call(pythonObject);
}
public static PythonObject bool() {
return attr("bool").call();
}
public static PythonObject boolType() {
return attr("bool");
}
public static PythonObject int_(PythonObject pythonObject) {
return attr("int").call(pythonObject);
}
public static PythonObject int_() {
return attr("int").call();
}
public static PythonObject intType() {
return attr("int");
}
public static PythonObject list(PythonObject pythonObject) {
return attr("list").call(pythonObject);
}
public static PythonObject list() {
return attr("list").call();
}
public static PythonObject listType() {
return attr("list");
}
public static PythonObject dict(PythonObject pythonObject) {
return attr("dict").call(pythonObject);
}
public static PythonObject dict() {
return attr("dict").call();
}
public static PythonObject dictType() {
return attr("dict");
}
public static PythonObject set(PythonObject pythonObject) {
return attr("set").call(pythonObject);
}
public static PythonObject set() {
return attr("set").call();
}
public static PythonObject bytearray(PythonObject pythonObject) {
return attr("bytearray").call(pythonObject);
}
public static PythonObject bytearray() {
return attr("bytearray").call();
}
public static PythonObject bytearrayType() {
return attr("bytearray");
}
public static PythonObject bytes(PythonObject pythonObject) {
return attr("bytes").call(pythonObject);
}
public static PythonObject bytes() {
return attr("bytes").call();
}
public static PythonObject bytesType() {
return attr("bytes");
}
public static PythonObject tuple(PythonObject pythonObject) {
return attr("tuple").call(pythonObject);
}
public static PythonObject tuple() {
return attr("tuple").call();
}
public static PythonObject Exception(PythonObject pythonObject) {
return attr("Exception").call(pythonObject);
}
public static PythonObject Exception() {
return attr("Exception").call();
}
public static PythonObject ExceptionType() {
return attr("Exception");
}
public static PythonObject tupleType() {
return attr("tuple");
}
public static PythonObject globals() {
return new PythonObject(PyModule_GetDict(PyImport_ImportModule("__main__")));
}
public static PythonObject type(PythonObject obj) {
return attr("type").call(obj);
}
public static boolean isinstance(PythonObject obj, PythonObject... type) {
return PyObject_IsInstance(obj.getNativePythonObject(),
PyList_AsTuple(new PythonObject(type).getNativePythonObject())) != 0;
}
public static PythonObject eval(String code) {
PyObject compiledCode = Py_CompileString(code, "", Py_eval_input);
PyObject globals = globals().getNativePythonObject();
PyObject locals = Python.dict().getNativePythonObject();
return new PythonObject(PyEval_EvalCode(compiledCode, globals, locals));
}
public static PythonObject builtins(){
try{
return importModule("builtins");
}catch (PythonException pe){
throw new IllegalStateException("Unable to import builtins: " + pe); // this should never happen
}
}
public static PythonObject None() {
return dict().attr("get").call(0);
}
public static PythonObject True() {
return boolType().call(1);
}
public static PythonObject False() {
return boolType().call(0);
}
public static boolean callable(PythonObject pythonObject) {
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
}
public static void setContext(String context) throws PythonException{
PythonContextManager.setContext(context);
}
public static String getCurrentContext(){
return PythonContextManager.getCurrentContext();
}
public static void deleteContext(String context) throws PythonException{
PythonContextManager.deleteContext(context);
}
public static void deleteNonMainContexts(){
PythonContextManager.deleteNonMainContexts();
}
public static void setMainContext(){PythonContextManager.setMainContext();}
public static void exec(String code)throws PythonException{
PythonExecutioner.exec(code);
}
public static void exec(String code, PythonVariables inputs) throws PythonException{
PythonExecutioner.exec(code, inputs);
}
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
PythonExecutioner.exec(code, inputs, outputs);
}
public static PythonGIL lock(){
return PythonGIL.lock();
}
}

View File

@ -16,14 +16,14 @@
package org.datavec.python; package org.datavec.python;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import java.util.List; import java.util.List;
import static org.datavec.python.PythonUtils.schemaToPythonVariables; import static org.datavec.python.PythonUtils.schemaToPythonVariables;
import static org.nd4j.base.Preconditions.checkNotNull;
import static org.nd4j.base.Preconditions.checkState;
/** /**
* Lets a condition be defined as a python method f that takes no arguments * Lets a condition be defined as a python method f that takes no arguments
@ -41,18 +41,16 @@ public class PythonCondition implements Condition {
public PythonCondition(String pythonCode) { public PythonCondition(String pythonCode) {
org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode); checkNotNull("Python code must not be null!", pythonCode);
org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!"); checkState(!pythonCode.isEmpty(), "Python code must not be empty!");
code = pythonCode; code = pythonCode;
} }
@Override @Override
public void setInputSchema(Schema inputSchema) { public void setInputSchema(Schema inputSchema) {
this.inputSchema = inputSchema; this.inputSchema = inputSchema;
try{ try {
pyInputs = schemaToPythonVariables(inputSchema); pyInputs = schemaToPythonVariables(inputSchema);
PythonVariables pyOuts = new PythonVariables(); PythonVariables pyOuts = new PythonVariables();
pyOuts.addInt("out"); pyOuts.addInt("out");
@ -62,17 +60,15 @@ public class PythonCondition implements Condition {
.outputs(pyOuts) .outputs(pyOuts)
.build(); .build();
} } catch (Exception e) {
catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Override @Override
public Schema getInputSchema(){ public Schema getInputSchema() {
return inputSchema; return inputSchema;
} }
@ -84,40 +80,39 @@ public class PythonCondition implements Condition {
} }
@Override @Override
public String outputColumnName(){ public String outputColumnName() {
return outputColumnNames()[0]; return outputColumnNames()[0];
} }
@Override @Override
public String[] columnNames(){ public String[] columnNames() {
return outputColumnNames(); return outputColumnNames();
} }
@Override @Override
public String columnName(){ public String columnName() {
return outputColumnName(); return outputColumnName();
} }
@Override @Override
public Schema transform(Schema inputSchema){ public Schema transform(Schema inputSchema) {
return inputSchema; return inputSchema;
} }
@Override @Override
public boolean condition(List<Writable> list) { public boolean condition(List<Writable> list) {
PythonVariables inputs = getPyInputsFromWritables(list); PythonVariables inputs = getPyInputsFromWritables(list);
try{ try {
PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs()); pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs());
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0; boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
return ret; return ret;
} } catch (Exception e) {
catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Override @Override
public boolean condition(Object input){ public boolean condition(Object input) {
return condition(input); return condition(input);
} }
@ -135,28 +130,27 @@ public class PythonCondition implements Condition {
private PythonVariables getPyInputsFromWritables(List<Writable> writables) { private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for (int i = 0; i < inputSchema.numColumns(); i++){ for (int i = 0; i < inputSchema.numColumns(); i++) {
String name = inputSchema.getName(i); String name = inputSchema.getName(i);
Writable w = writables.get(i); Writable w = writables.get(i);
PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i)); PythonType pyType = pyInputs.getType(inputSchema.getName(i));
switch (pyType){ switch (pyType.getName()) {
case INT: case INT:
if (w instanceof LongWritable) { if (w instanceof LongWritable) {
ret.addInt(name, ((LongWritable)w).get()); ret.addInt(name, ((LongWritable) w).get());
} } else {
else { ret.addInt(name, ((IntWritable) w).get());
ret.addInt(name, ((IntWritable)w).get());
} }
break; break;
case FLOAT: case FLOAT:
ret.addFloat(name, ((DoubleWritable)w).get()); ret.addFloat(name, ((DoubleWritable) w).get());
break; break;
case STR: case STR:
ret.addStr(name, w.toString()); ret.addStr(name, w.toString());
break; break;
case NDARRAY: case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get()); ret.addNDArray(name, ((NDArrayWritable) w).get());
break; break;
} }
} }

View File

@ -0,0 +1,188 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Emulates multiples interpreters in a single interpreter.
* This works by simply obfuscating/de-obfuscating variable names
* such that only the required subset of the global namespace is "visible"
* at any given time.
* By default, there exists a "main" context emulating the default interpreter
* and cannot be deleted.
* @author Fariz Rahman
*/
public class PythonContextManager {
private static Set<String> contexts = new HashSet<>();
private static AtomicBoolean init = new AtomicBoolean(false);
private static String currentContext;
private static final String MAIN_CONTEXT = "main";
static {
init();
}
private static void init() {
if (init.get()) return;
new PythonExecutioner();
init.set(true);
currentContext = MAIN_CONTEXT;
contexts.add(currentContext);
}
public static void addContext(String contextName) throws PythonException {
if (!validateContextName(contextName)) {
throw new PythonException("Invalid context name: " + contextName);
}
contexts.add(contextName);
}
public static boolean hasContext(String contextName) {
return contexts.contains(contextName);
}
public static boolean validateContextName(String s) {
if (s.length() == 0) return false;
if (!Character.isJavaIdentifierStart(s.charAt(0))) return false;
for (int i = 1; i < s.length(); i++)
if (!Character.isJavaIdentifierPart(s.charAt(i)))
return false;
return true;
}
private static String getContextPrefix(String contextName) {
return "__collapsed__" + contextName + "__";
}
private static String getCollapsedVarNameForContext(String varName, String contextName) {
return getContextPrefix(contextName) + varName;
}
private static String expandCollapsedVarName(String varName, String contextName) {
String prefix = "__collapsed__" + contextName + "__";
return varName.substring(prefix.length());
}
private static void collapseContext(String contextName) {
PythonObject globals = Python.globals();
PythonObject keysList = Python.list(globals.attr("keys").call());
int numKeys = Python.len(keysList).toInt();
for (int i = 0; i < numKeys; i++) {
PythonObject key = keysList.get(i);
String keyStr = key.toString();
if (!((keyStr.startsWith("__") && keyStr.endsWith("__")) || keyStr.startsWith("__collapsed_"))) {
String collapsedKey = getCollapsedVarNameForContext(keyStr, contextName);
PythonObject val = globals.attr("pop").call(key);
globals.set(new PythonObject(collapsedKey), val);
}
}
}
private static void expandContext(String contextName) {
String prefix = getContextPrefix(contextName);
PythonObject globals = Python.globals();
PythonObject keysList = Python.list(globals.attr("keys").call());
int numKeys = Python.len(keysList).toInt();
for (int i = 0; i < numKeys; i++) {
PythonObject key = keysList.get(i);
String keyStr = key.toString();
if (keyStr.startsWith(prefix)) {
String expandedKey = expandCollapsedVarName(keyStr, contextName);
PythonObject val = globals.attr("pop").call(key);
globals.set(new PythonObject(expandedKey), val);
}
}
}
public static void setContext(String contextName) throws PythonException{
if (contextName.equals(currentContext)) {
return;
}
if (!hasContext(contextName)) {
addContext(contextName);
}
collapseContext(currentContext);
expandContext(contextName);
currentContext = contextName;
}
public static void setMainContext() {
try{
setContext(MAIN_CONTEXT);
}
catch (PythonException pe){
throw new RuntimeException(pe);
}
}
public static String getCurrentContext() {
return currentContext;
}
public static void deleteContext(String contextName) throws PythonException {
if (contextName.equals(MAIN_CONTEXT)) {
throw new PythonException("Can not delete main context!");
}
if (contextName.equals(currentContext)) {
throw new PythonException("Can not delete current context!");
}
String prefix = getContextPrefix(contextName);
PythonObject globals = Python.globals();
PythonObject keysList = Python.list(globals.attr("keys").call());
int numKeys = Python.len(keysList).toInt();
for (int i = 0; i < numKeys; i++) {
PythonObject key = keysList.get(i);
String keyStr = key.toString();
if (keyStr.startsWith(prefix)) {
globals.attr("__delitem__").call(key);
}
}
contexts.remove(contextName);
}
public static void deleteNonMainContexts() {
try{
setContext(MAIN_CONTEXT); // will never fail
for (String c : contexts.toArray(new String[0])) {
if (!c.equals(MAIN_CONTEXT)) {
deleteContext(c); // will never fail
}
}
}catch(Exception e){
throw new RuntimeException(e);
}
}
public String[] getContexts() {
return contexts.toArray(new String[0]);
}
}

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
/**
* Thrown when an exception occurs in python land
*/
public class PythonException extends Exception {
public PythonException(String message){
super(message);
}
private static String getExceptionString(PythonObject exception){
if (Python.isinstance(exception, Python.ExceptionType())){
String exceptionClass = Python.type(exception).attr("__name__").toString();
String message = exception.toString();
return exceptionClass + ": " + message;
}
return exception.toString();
}
public PythonException(PythonObject exception){
this(getExceptionString(exception));
}
public PythonException(String message, Throwable cause){
super(message, cause);
}
public PythonException(Throwable cause){
super(cause);
}
}

View File

@ -0,0 +1,68 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.cpython.PyThreadState;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.PyEval_RestoreThread;
import static org.bytedeco.cpython.global.python.PyEval_SaveThread;
@Slf4j
public class PythonGIL implements AutoCloseable {
private static PyThreadState mainThreadState;
static {
log.debug("CPython: PyThreadState_Get()");
mainThreadState = PyThreadState_Get();
}
private PythonGIL() {
acquire();
}
@Override
public void close() {
release();
}
public static PythonGIL lock() {
return new PythonGIL();
}
private static synchronized void acquire() {
log.debug("acquireGIL()");
log.debug("CPython: PyEval_SaveThread()");
mainThreadState = PyEval_SaveThread();
log.debug("CPython: PyThreadState_New()");
PyThreadState ts = PyThreadState_New(mainThreadState.interp());
log.debug("CPython: PyEval_RestoreThread()");
PyEval_RestoreThread(ts);
log.debug("CPython: PyThreadState_Swap()");
PyThreadState_Swap(ts);
}
private static synchronized void release() {
log.debug("CPython: PyEval_SaveThread()");
PyEval_SaveThread();
log.debug("CPython: PyEval_RestoreThread()");
PyEval_RestoreThread(mainThreadState);
}
}

View File

@ -0,0 +1,171 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import javax.annotation.Nonnull;
import java.util.HashMap;
import java.util.Map;
@Data
@NoArgsConstructor
/**
* PythonJob is the right abstraction for executing multiple python scripts
* in a multi thread stateful environment. The setup-and-run mode allows your
* "setup" code (imports, model loading etc) to be executed only once.
*/
public class PythonJob {
private String code;
private String name;
private String context;
private boolean setupRunMode;
private PythonObject runF;
static {
new PythonExecutioner();
}
@Builder
/**
* @param name Name for the python job.
* @param code Python code.
* @param setupRunMode If true, the python code is expected to have two methods: setup(), which takes no arguments,
* and run() which takes some or no arguments. setup() method is executed once,
* and the run() method is called with the inputs(if any) per transaction, and is expected to return a dictionary
* mapping from output variable names (str) to output values.
* If false, the full script is run on each transaction and the output variables are obtained from the global namespace
* after execution.
*/
public PythonJob(@Nonnull String name, @Nonnull String code, boolean setupRunMode) throws Exception {
this.name = name;
this.code = code;
this.setupRunMode = setupRunMode;
context = "__job_" + name;
if (PythonContextManager.hasContext(context)) {
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
}
if (setupRunMode) setup();
}
/**
* Clears all variables in current context and calls setup()
*/
public void clearState() throws Exception {
String context = this.context;
PythonContextManager.setContext("main");
PythonContextManager.deleteContext(context);
this.context = context;
setup();
}
public void setup() throws Exception {
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
PythonObject runF = PythonExecutioner.getVariable("run");
if (runF.isNone() || !Python.callable(runF)) {
PythonExecutioner.exec(code);
runF = PythonExecutioner.getVariable("run");
}
if (runF.isNone() || !Python.callable(runF)) {
throw new PythonException("run() method not found! " +
"If a PythonJob is created with 'setup and run' " +
"mode enabled, the associated python code is " +
"expected to contain a run() method " +
"(with or without arguments).");
}
this.runF = runF;
PythonObject setupF = PythonExecutioner.getVariable("setup");
if (!setupF.isNone()) {
setupF.call();
}
}
}
public void exec(PythonVariables inputs, PythonVariables outputs) throws Exception {
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
if (!setupRunMode) {
PythonExecutioner.exec(code, inputs, outputs);
return;
}
PythonExecutioner.setVariables(inputs);
PythonObject inspect = Python.importModule("inspect");
PythonObject getfullargspec = inspect.attr("getfullargspec");
PythonObject argspec = getfullargspec.call(runF);
PythonObject argsList = argspec.attr("args");
PythonObject runargs = Python.dict();
int argsCount = Python.len(argsList).toInt();
for (int i = 0; i < argsCount; i++) {
PythonObject arg = argsList.get(i);
PythonObject val = Python.globals().get(arg);
if (val.isNone()) {
throw new PythonException("Input value not received for run() argument: " + arg.toString());
}
runargs.set(arg, val);
}
PythonObject outDict = runF.callWithKwargs(runargs);
Python.globals().attr("update").call(outDict);
PythonExecutioner.getVariables(outputs);
inspect.del();
getfullargspec.del();
argspec.del();
runargs.del();
}
}
public PythonVariables execAndReturnAllVariables(PythonVariables inputs) throws Exception {
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
if (!setupRunMode) {
return PythonExecutioner.execAndReturnAllVariables(code, inputs);
}
PythonExecutioner.setVariables(inputs);
PythonObject inspect = Python.importModule("inspect");
PythonObject getfullargspec = inspect.attr("getfullargspec");
PythonObject argspec = getfullargspec.call(runF);
PythonObject argsList = argspec.attr("args");
PythonObject runargs = Python.dict();
int argsCount = Python.len(argsList).toInt();
for (int i = 0; i < argsCount; i++) {
PythonObject arg = argsList.get(i);
PythonObject val = Python.globals().get(arg);
if (val.isNone()) {
throw new PythonException("Input value not received for run() argument: " + arg.toString());
}
runargs.set(arg, val);
}
PythonObject outDict = runF.callWithKwargs(runargs);
Python.globals().attr("update").call(outDict);
inspect.del();
getfullargspec.del();
argspec.del();
runargs.del();
return PythonExecutioner.getAllVariables();
}
}
}

View File

@ -0,0 +1,554 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.PyObject_SetItem;
/**
* Swift like python wrapper for J
*
* @author Fariz Rahman
*/
public class PythonObject {
private PyObject nativePythonObject;
static {
new PythonExecutioner();
}
private static Map<String, PythonObject> _getNDArraySerializer() {
Map<String, PythonObject> ndarraySerializer = new HashMap<>();
PythonObject lambda = Python.eval(
"lambda x: " +
"{'address':" +
"x.__array_interface__['data'][0]," +
"'shape':x.shape,'strides':x.strides," +
"'dtype': str(x.dtype),'_is_numpy_array': True}" +
" if str(type(x))== \"<class 'numpy.ndarray'>\" else x");
ndarraySerializer.put("default",
lambda);
return ndarraySerializer;
}
public PythonObject(PyObject pyObject) {
nativePythonObject = pyObject;
}
public PythonObject(INDArray npArray) {
this(new NumpyArray(npArray));
}
public PythonObject(BytePointer bp){
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
}
public PythonObject(NumpyArray npArray) {
PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject np = PyImport_ImportModule("numpy");
PyObject ctype;
switch (npArray.getDtype()) {
case DOUBLE:
ctype = PyObject_GetAttrString(ctypes, "c_double");
break;
case FLOAT:
ctype = PyObject_GetAttrString(ctypes, "c_float");
break;
case LONG:
ctype = PyObject_GetAttrString(ctypes, "c_int64");
break;
case INT:
ctype = PyObject_GetAttrString(ctypes, "c_int32");
break;
case SHORT:
ctype = PyObject_GetAttrString(ctypes, "c_int16");
break;
case UINT16:
ctype = PyObject_GetAttrString(ctypes, "c_uint16");
break;
case UINT32:
ctype = PyObject_GetAttrString(ctypes, "c_uint32");
break;
case UINT64:
ctype = PyObject_GetAttrString(ctypes, "c_uint64");
break;
case BOOL:
ctype = PyObject_GetAttrString(ctypes, "c_bool");
break;
case BYTE:
ctype = PyObject_GetAttrString(ctypes, "c_byte");
break;
case UBYTE:
ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
break;
default:
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
}
PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
PyObject argsTuple = PyTuple_New(1);
PyTuple_SetItem(argsTuple, 0, ctype);
PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
PyObject address = PyLong_FromLong(npArray.getAddress());
PyObject argsTuple2 = PyTuple_New(2);
PyTuple_SetItem(argsTuple2, 0, address);
PyTuple_SetItem(argsTuple2, 1, ptrType);
PyObject ptr = PyObject_Call(cast, argsTuple2, null);
PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
for (int i = 0; i < npArray.getShape().length; i++) {
PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
PyTuple_SetItem(shapeTuple, i, dim);
Py_DecRef(dim);
}
PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
PyObject argsTuple3 = PyTuple_New(2);
PyTuple_SetItem(argsTuple3, 0, ptr);
PyTuple_SetItem(argsTuple3, 1, shapeTuple);
nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
Py_DecRef(ctypesPointer);
Py_DecRef(ctypesLib);
Py_DecRef(argsTuple);
Py_DecRef(argsTuple2);
Py_DecRef(argsTuple3);
Py_DecRef(cast);
Py_DecRef(asArray);
}
/*---primitve constructors---*/
public PyObject getNativePythonObject() {
return nativePythonObject;
}
public PythonObject(String data) {
nativePythonObject = PyUnicode_FromString(data);
}
public PythonObject(int data) {
nativePythonObject = PyLong_FromLong((long) data);
}
public PythonObject(long data) {
nativePythonObject = PyLong_FromLong(data);
}
public PythonObject(double data) {
nativePythonObject = PyFloat_FromDouble(data);
}
public PythonObject(boolean data) {
nativePythonObject = PyBool_FromLong(data ? 1 : 0);
}
private static PythonObject j2pyObject(Object item) {
if (item instanceof PythonObject) {
return (PythonObject) item;
} else if (item instanceof PyObject) {
return new PythonObject((PyObject) item);
} else if (item instanceof INDArray) {
return new PythonObject((INDArray) item);
} else if (item instanceof NumpyArray) {
return new PythonObject((NumpyArray) item);
} else if (item instanceof List) {
return new PythonObject((List) item);
} else if (item instanceof Object[]) {
return new PythonObject((Object[]) item);
} else if (item instanceof Map) {
return new PythonObject((Map) item);
} else if (item instanceof String) {
return new PythonObject((String) item);
} else if (item instanceof Double) {
return new PythonObject((Double) item);
} else if (item instanceof Float) {
return new PythonObject((Float) item);
} else if (item instanceof Long) {
return new PythonObject((Long) item);
} else if (item instanceof Integer) {
return new PythonObject((Integer) item);
} else if (item instanceof Boolean) {
return new PythonObject((Boolean) item);
} else if (item instanceof Pointer){
return new PythonObject(new BytePointer((Pointer)item));
} else {
throw new RuntimeException("Unsupported item in list: " + item);
}
}
public PythonObject(Object[] data) {
PyObject pyList = PyList_New((long) data.length);
for (int i = 0; i < data.length; i++) {
PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject);
}
nativePythonObject = pyList;
}
public PythonObject(List data) {
PyObject pyList = PyList_New((long) data.size());
for (int i = 0; i < data.size(); i++) {
PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject);
}
nativePythonObject = pyList;
}
public PythonObject(Map data) {
PyObject pyDict = PyDict_New();
for (Object k : data.keySet()) {
PythonObject pyKey;
if (k instanceof PythonObject) {
pyKey = (PythonObject) k;
} else if (k instanceof String) {
pyKey = new PythonObject((String) k);
} else if (k instanceof Double) {
pyKey = new PythonObject((Double) k);
} else if (k instanceof Float) {
pyKey = new PythonObject((Float) k);
} else if (k instanceof Long) {
pyKey = new PythonObject((Long) k);
} else if (k instanceof Integer) {
pyKey = new PythonObject((Integer) k);
} else if (k instanceof Boolean) {
pyKey = new PythonObject((Boolean) k);
} else {
throw new RuntimeException("Unsupported key in map: " + k.getClass());
}
Object v = data.get(k);
PythonObject pyVal;
if (v instanceof PythonObject) {
pyVal = (PythonObject) v;
} else if (v instanceof PyObject) {
pyVal = new PythonObject((PyObject) v);
} else if (v instanceof INDArray) {
pyVal = new PythonObject((INDArray) v);
} else if (v instanceof NumpyArray) {
pyVal = new PythonObject((NumpyArray) v);
} else if (v instanceof Map) {
pyVal = new PythonObject((Map) v);
} else if (v instanceof List) {
pyVal = new PythonObject((List) v);
} else if (v instanceof String) {
pyVal = new PythonObject((String) v);
} else if (v instanceof Double) {
pyVal = new PythonObject((Double) v);
} else if (v instanceof Float) {
pyVal = new PythonObject((Float) v);
} else if (v instanceof Long) {
pyVal = new PythonObject((Long) v);
} else if (v instanceof Integer) {
pyVal = new PythonObject((Integer) v);
} else if (v instanceof Boolean) {
pyVal = new PythonObject((Boolean) v);
} else {
throw new RuntimeException("Unsupported value in map: " + k.getClass());
}
PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject);
}
nativePythonObject = pyDict;
}
/*------*/
private static String pyObjectToString(PyObject pyObject) {
PyObject repr = PyObject_Str(pyObject);
PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~");
String jstr = PyBytes_AsString(str).getString();
Py_DecRef(repr);
Py_DecRef(str);
return jstr;
}
public String toString() {
return pyObjectToString(nativePythonObject);
}
public double toDouble() {
return PyFloat_AsDouble(nativePythonObject);
}
public float toFloat() {
return (float) PyFloat_AsDouble(nativePythonObject);
}
public int toInt() {
return (int) PyLong_AsLong(nativePythonObject);
}
public long toLong() {
return PyLong_AsLong(nativePythonObject);
}
public boolean toBoolean() {
if (isNone()) return false;
return toInt() != 0;
}
public NumpyArray toNumpy() {
PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
PyObject data = PyDict_GetItemString(arrInterface, "data");
PyObject pyAddress = PyTuple_GetItem(data, 0);
long address = PyLong_AsLong(pyAddress);
PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
String dtypeName = pyObjectToString(pyDtypeName);
Py_DecRef(pyDtype);
Py_DecRef(pyDtypeName);
PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
int ndim = (int) PyObject_Size(shape);
long[] jshape = new long[ndim];
long[] jstrides = new long[ndim];
for (int i = 0; i < ndim; i++) {
jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
}
Py_DecRef(shape);
Py_DecRef(strides);
DataType dtype;
if (dtypeName.equals("float64")) {
dtype = DataType.DOUBLE;
} else if (dtypeName.equals("float32")) {
dtype = DataType.FLOAT;
} else if (dtypeName.equals("int16")) {
dtype = DataType.SHORT;
} else if (dtypeName.equals("int32")) {
dtype = DataType.INT;
} else if (dtypeName.equals("int64")) {
dtype = DataType.LONG;
} else {
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
return new NumpyArray(address, jshape, jstrides, dtype);
}
public PythonObject attr(String attr) {
return new PythonObject(PyObject_GetAttrString(nativePythonObject, attr));
}
public PythonObject call(Object... args) {
if (args.length > 0 && args[args.length - 1] instanceof Map) {
List<Object> args2 = new ArrayList<>();
for (int i = 0; i < args.length - 1; i++) {
args2.add(args[i]);
}
return call(args2, (Map) args[args.length - 1]);
}
if (args.length == 0) {
return new PythonObject(PyObject_CallObject(nativePythonObject, null));
}
PyObject tuple = PyTuple_New(args.length); // leaky; tuple may contain borrowed references, so can not be de-allocated.
for (int i = 0; i < args.length; i++) {
PyTuple_SetItem(tuple, i, j2pyObject(args[i]).nativePythonObject);
}
PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
return ret;
}
public PythonObject callWithArgs(PythonObject args) {
PyObject tuple = PyList_AsTuple(args.nativePythonObject);
return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
}
public PythonObject callWithKwargs(PythonObject kwargs) {
PyObject tuple = PyTuple_New(0);
return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs.nativePythonObject));
}
public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) {
PyObject tuple = PyList_AsTuple(args.nativePythonObject);
PyObject dict = kwargs.nativePythonObject;
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
}
public PythonObject call(Map kwargs) {
PyObject dict = new PythonObject(kwargs).nativePythonObject;
PyObject tuple = PyTuple_New(0);
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
}
public PythonObject call(List args) {
PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
}
public PythonObject call(List args, Map kwargs) {
PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
PyObject dict = new PythonObject(kwargs).nativePythonObject;
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
}
private PythonObject get(PyObject key) {
return new PythonObject(
PyObject_GetItem(nativePythonObject, key)
);
}
public PythonObject get(PythonObject key) {
return get(key.nativePythonObject);
}
public PythonObject get(int key) {
return get(PyLong_FromLong((long) key));
}
public PythonObject get(long key) {
return new PythonObject(
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
);
}
public PythonObject get(double key) {
return new PythonObject(
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
);
}
public PythonObject get(String key) {
return get(new PythonObject(key));
}
public void set(PythonObject key, PythonObject value) {
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
}
public void del() {
Py_DecRef(nativePythonObject);
nativePythonObject = null;
}
public JSONArray toJSONArray() throws PythonException {
PythonObject json = Python.importModule("json");
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
String jsonString = serialized.toString();
return new JSONArray(jsonString);
}
public JSONObject toJSONObject() throws PythonException {
PythonObject json = Python.importModule("json");
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
String jsonString = serialized.toString();
return new JSONObject(jsonString);
}
public List toList() throws PythonException{
List list = new ArrayList();
int n = Python.len(this).toInt();
for (int i = 0; i < n; i++) {
PythonObject o = get(i);
if (Python.isinstance(o, Python.strType())) {
list.add(o.toString());
} else if (Python.isinstance(o, Python.intType())) {
list.add(o.toLong());
} else if (Python.isinstance(o, Python.floatType())) {
list.add(o.toDouble());
} else if (Python.isinstance(o, Python.boolType())) {
list.add(o);
} else if (Python.isinstance(o, Python.listType(), Python.tupleType())) {
list.add(o.toList());
} else if (Python.isinstance(o, Python.importModule("numpy").attr("ndarray"))) {
list.add(o.toNumpy().getNd4jArray());
} else if (Python.isinstance(o, Python.dictType())) {
list.add(o.toMap());
} else {
throw new RuntimeException("Error while converting python" +
" list to java List: Unable to serialize python " +
"object of type " + Python.type(this).toString());
}
}
return list;
}
public Map toMap() throws PythonException{
Map map = new HashMap();
List keys = Python.list(attr("keys").call()).toList();
List values = Python.list(attr("values").call()).toList();
for (int i = 0; i < keys.size(); i++) {
map.put(keys.get(i), values.get(i));
}
return map;
}
public BytePointer toBytePointer() throws PythonException{
if (Python.isinstance(this, Python.bytesType())){
PyObject byteArray = PyByteArray_FromObject(nativePythonObject);
return PyByteArray_AsString(byteArray);
}
else if (Python.isinstance(this, Python.bytearrayType())){
return PyByteArray_AsString(nativePythonObject);
}
else{
PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
PyObject cVoidP = PyObject_GetAttrString(ctypes, "c_void_p");
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
PyObject argsTuple = PyTuple_New(2);
PyTuple_SetItem(argsTuple, 0, nativePythonObject);
PyTuple_SetItem(argsTuple, 1, cVoidP);
PyObject voidPtr = PyObject_Call(cast, argsTuple, null);
PyObject pyAddress = PyObject_GetAttrString(voidPtr, "value");
long address = PyLong_AsLong(pyAddress);
long size = PyObject_Size(nativePythonObject);
Py_DecRef(ctypes);
Py_DecRef(cArrType);
Py_DecRef(argsTuple);
Py_DecRef(voidPtr);
Py_DecRef(pyAddress);
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
ptr = ptr.limit(size);
ptr = ptr.capacity(size);
return new BytePointer(ptr);
}
else{
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
}
}
}
public boolean isNone() {
return nativePythonObject == null;
}
}

View File

@ -24,10 +24,13 @@ import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform; import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.json.JSONPropertyIgnore;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -52,12 +55,13 @@ public class PythonTransform implements Transform {
private String code; private String code;
private PythonVariables inputs; private PythonVariables inputs;
private PythonVariables outputs; private PythonVariables outputs;
private String name = UUID.randomUUID().toString(); private String name = UUID.randomUUID().toString();
private Schema inputSchema; private Schema inputSchema;
private Schema outputSchema; private Schema outputSchema;
private String outputDict; private String outputDict;
private boolean returnAllVariables; private boolean returnAllVariables;
private boolean setupAndRun = false; private boolean setupAndRun = false;
private PythonJob pythonJob;
@Builder @Builder
@ -70,71 +74,70 @@ public class PythonTransform implements Transform {
String outputDict, String outputDict,
boolean returnAllInputs, boolean returnAllInputs,
boolean setupAndRun) { boolean setupAndRun) {
Preconditions.checkNotNull(code,"No code found to run!"); Preconditions.checkNotNull(code, "No code found to run!");
this.code = code; this.code = code;
this.returnAllVariables = returnAllInputs; this.returnAllVariables = returnAllInputs;
this.setupAndRun = setupAndRun; this.setupAndRun = setupAndRun;
if(inputs != null) if (inputs != null)
this.inputs = inputs; this.inputs = inputs;
if(outputs != null) if (outputs != null)
this.outputs = outputs; this.outputs = outputs;
if (name != null)
if(name != null)
this.name = name; this.name = name;
if (outputDict != null) { if (outputDict != null) {
this.outputDict = outputDict; this.outputDict = outputDict;
this.outputs = new PythonVariables(); this.outputs = new PythonVariables();
this.outputs.addDict(outputDict); this.outputs.addDict(outputDict);
String helpers;
try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
helpers = IOUtils.toString(is, Charset.defaultCharset());
}catch (IOException e){
throw new RuntimeException("Error reading python code");
}
this.code += "\n\n" + helpers;
this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
} }
try { try {
if(inputSchema != null) { if (inputSchema != null) {
this.inputSchema = inputSchema; this.inputSchema = inputSchema;
if(inputs == null || inputs.isEmpty()) { if (inputs == null || inputs.isEmpty()) {
this.inputs = schemaToPythonVariables(inputSchema); this.inputs = schemaToPythonVariables(inputSchema);
} }
} }
if(outputSchema != null) { if (outputSchema != null) {
this.outputSchema = outputSchema; this.outputSchema = outputSchema;
if(outputs == null || outputs.isEmpty()) { if (outputs == null || outputs.isEmpty()) {
this.outputs = schemaToPythonVariables(outputSchema); this.outputs = schemaToPythonVariables(outputSchema);
} }
} }
}catch(Exception e) { } catch (Exception e) {
throw new IllegalStateException(e); throw new IllegalStateException(e);
} }
try{
pythonJob = PythonJob.builder()
.name("a" + UUID.randomUUID().toString().replace("-", "_"))
.code(code)
.setupRunMode(setupAndRun)
.build();
}
catch(Exception e){
throw new IllegalStateException("Error creating python job: " + e);
}
} }
@Override @Override
public void setInputSchema(Schema inputSchema) { public void setInputSchema(Schema inputSchema) {
Preconditions.checkNotNull(inputSchema,"No input schema found!"); Preconditions.checkNotNull(inputSchema, "No input schema found!");
this.inputSchema = inputSchema; this.inputSchema = inputSchema;
try{ try {
inputs = schemaToPythonVariables(inputSchema); inputs = schemaToPythonVariables(inputSchema);
}catch (Exception e){ } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
if (outputSchema == null && outputDict == null){ if (outputSchema == null && outputDict == null) {
outputSchema = inputSchema; outputSchema = inputSchema;
} }
} }
@Override @Override
public Schema getInputSchema(){ public Schema getInputSchema() {
return inputSchema; return inputSchema;
} }
@ -158,67 +161,51 @@ public class PythonTransform implements Transform {
} }
@Override @Override
public List<Writable> map(List<Writable> writables) { public List<Writable> map(List<Writable> writables) {
PythonVariables pyInputs = getPyInputsFromWritables(writables); PythonVariables pyInputs = getPyInputsFromWritables(writables);
Preconditions.checkNotNull(pyInputs,"Inputs must not be null!"); Preconditions.checkNotNull(pyInputs, "Inputs must not be null!");
try {
try{
if (returnAllVariables) { if (returnAllVariables) {
if (setupAndRun){ return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs));
return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
}
return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
} }
if (outputDict != null) { if (outputDict != null) {
if (setupAndRun) { pythonJob.exec(pyInputs, outputs);
PythonExecutioner.execWithSetupAndRun(this, pyInputs);
}else{
PythonExecutioner.exec(this, pyInputs);
}
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict); PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
return getWritablesFromPyOutputs(out); return getWritablesFromPyOutputs(out);
} } else {
else { pythonJob.exec(pyInputs, outputs);
if (setupAndRun) {
PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
}else{
PythonExecutioner.exec(code, pyInputs, outputs);
}
return getWritablesFromPyOutputs(outputs); return getWritablesFromPyOutputs(outputs);
} }
} } catch (Exception e) {
catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Override @Override
public String[] outputColumnNames(){ public String[] outputColumnNames() {
return outputs.getVariables(); return outputs.getVariables();
} }
@Override @Override
public String outputColumnName(){ public String outputColumnName() {
return outputColumnNames()[0]; return outputColumnNames()[0];
} }
@Override @Override
public String[] columnNames(){ public String[] columnNames() {
return outputs.getVariables(); return outputs.getVariables();
} }
@Override @Override
public String columnName(){ public String columnName() {
return columnNames()[0]; return columnNames()[0];
} }
public Schema transform(Schema inputSchema){ public Schema transform(Schema inputSchema) {
return outputSchema; return outputSchema;
} }
@ -226,33 +213,33 @@ public class PythonTransform implements Transform {
private PythonVariables getPyInputsFromWritables(List<Writable> writables) { private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for (String name: inputs.getVariables()) { for (String name : inputs.getVariables()) {
int colIdx = inputSchema.getIndexOfColumn(name); int colIdx = inputSchema.getIndexOfColumn(name);
Writable w = writables.get(colIdx); Writable w = writables.get(colIdx);
PythonVariables.Type pyType = inputs.getType(name); PythonType pyType = inputs.getType(name);
switch (pyType){ switch (pyType.getName()) {
case INT: case INT:
if (w instanceof LongWritable){ if (w instanceof LongWritable) {
ret.addInt(name, ((LongWritable)w).get()); ret.addInt(name, ((LongWritable) w).get());
} else {
ret.addInt(name, ((IntWritable) w).get());
} }
else{
ret.addInt(name, ((IntWritable)w).get());
}
break; break;
case FLOAT: case FLOAT:
if (w instanceof DoubleWritable) { if (w instanceof DoubleWritable) {
ret.addFloat(name, ((DoubleWritable)w).get()); ret.addFloat(name, ((DoubleWritable) w).get());
} } else {
else{ ret.addFloat(name, ((FloatWritable) w).get());
ret.addFloat(name, ((FloatWritable)w).get());
} }
break; break;
case STR: case STR:
ret.addStr(name, w.toString()); ret.addStr(name, w.toString());
break; break;
case NDARRAY: case NDARRAY:
ret.addNDArray(name,((NDArrayWritable)w).get()); ret.addNDArray(name, ((NDArrayWritable) w).get());
break;
case BOOL:
ret.addBool(name, ((BooleanWritable) w).get());
break; break;
default: default:
throw new RuntimeException("Unsupported input type:" + pyType); throw new RuntimeException("Unsupported input type:" + pyType);
@ -269,8 +256,8 @@ public class PythonTransform implements Transform {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for (int i = 0; i < varNames.length; i++) { for (int i = 0; i < varNames.length; i++) {
String name = varNames[i]; String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name); PythonType pyType = pyOuts.getType(name);
switch (pyType){ switch (pyType.getName()) {
case INT: case INT:
schemaBuilder.addColumnLong(name); schemaBuilder.addColumnLong(name);
break; break;
@ -283,11 +270,14 @@ public class PythonTransform implements Transform {
schemaBuilder.addColumnString(name); schemaBuilder.addColumnString(name);
break; break;
case NDARRAY: case NDARRAY:
NumpyArray arr = pyOuts.getNDArrayValue(name); INDArray arr = pyOuts.getNDArrayValue(name);
schemaBuilder.addColumnNDArray(name, arr.getShape()); schemaBuilder.addColumnNDArray(name, arr.shape());
break;
case BOOL:
schemaBuilder.addColumnBoolean(name);
break; break;
default: default:
throw new IllegalStateException("Unable to support type " + pyType.name()); throw new IllegalStateException("Unable to support type " + pyType.getName());
} }
} }
this.outputSchema = schemaBuilder.build(); this.outputSchema = schemaBuilder.build();
@ -295,9 +285,9 @@ public class PythonTransform implements Transform {
for (int i = 0; i < varNames.length; i++) { for (int i = 0; i < varNames.length; i++) {
String name = varNames[i]; String name = varNames[i];
PythonVariables.Type pyType = pyOuts.getType(name); PythonType pyType = pyOuts.getType(name);
switch (pyType){ switch (pyType.getName()) {
case INT: case INT:
out.add(new LongWritable(pyOuts.getIntValue(name))); out.add(new LongWritable(pyOuts.getIntValue(name)));
break; break;
@ -308,14 +298,14 @@ public class PythonTransform implements Transform {
out.add(new Text(pyOuts.getStrValue(name))); out.add(new Text(pyOuts.getStrValue(name)));
break; break;
case NDARRAY: case NDARRAY:
NumpyArray arr = pyOuts.getNDArrayValue(name); INDArray arr = pyOuts.getNDArrayValue(name);
out.add(new NDArrayWritable(arr.getNd4jArray())); out.add(new NDArrayWritable(arr));
break; break;
case DICT: case DICT:
Map<?, ?> dictValue = pyOuts.getDictValue(name); Map<?, ?> dictValue = pyOuts.getDictValue(name);
Map noNullValues = new java.util.HashMap<>(); Map noNullValues = new java.util.HashMap<>();
for(Map.Entry entry : dictValue.entrySet()) { for (Map.Entry entry : dictValue.entrySet()) {
if(entry.getValue() != org.json.JSONObject.NULL) { if (entry.getValue() != org.json.JSONObject.NULL) {
noNullValues.put(entry.getKey(), entry.getValue()); noNullValues.put(entry.getKey(), entry.getValue());
} }
} }
@ -327,21 +317,22 @@ public class PythonTransform implements Transform {
} }
break; break;
case LIST: case LIST:
Object[] listValue = pyOuts.getListValue(name); Object[] listValue = pyOuts.getListValue(name).toArray();
try { try {
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue))); out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
} }
break; break;
case BOOL:
out.add(new BooleanWritable(pyOuts.getBooleanValue(name)));
break;
default: default:
throw new IllegalStateException("Unable to support type " + pyType.name()); throw new IllegalStateException("Unable to support type " + pyType.getName());
} }
} }
return out; return out;
} }
} }

View File

@ -0,0 +1,238 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.datavec.python.Python.importModule;
/**
*
* @param <T> Corresponding Java type for the Python type
*/
public abstract class PythonType<T> {
public abstract T toJava(PythonObject pythonObject) throws PythonException;
private final TypeName typeName;
enum TypeName{
STR,
INT,
FLOAT,
BOOL,
LIST,
DICT,
NDARRAY,
BYTES
}
private PythonType(TypeName typeName){
this.typeName = typeName;
}
public TypeName getName(){return typeName;}
public String toString(){
return getName().name();
}
public static PythonType valueOf(String typeName) throws PythonException{
try{
typeName.valueOf(typeName);
} catch (IllegalArgumentException iae){
throw new PythonException("Invalid python type: " + typeName, iae);
}
try{
return (PythonType)PythonType.class.getField(typeName).get(null); // shouldn't fail
} catch (Exception e){
throw new RuntimeException(e);
}
}
public static PythonType valueOf(TypeName typeName){
try{
return valueOf(typeName.name()); // shouldn't fail
}catch (PythonException pe){
throw new RuntimeException(pe);
}
}
/**
* Since multiple java types can map to the same python type,
* this method "normalizes" all supported incoming objects to T
*
* @param object object to be converted to type T
* @return
*/
public T convert(Object object) throws PythonException {
return (T) object;
}
public static final PythonType<String> STR = new PythonType<String>(TypeName.STR) {
@Override
public String toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.strType())) {
throw new PythonException("Expected variable to be str, but was " + Python.type(pythonObject));
}
return pythonObject.toString();
}
@Override
public String convert(Object object) {
return object.toString();
}
};
public static final PythonType<Long> INT = new PythonType<Long>(TypeName.INT) {
@Override
public Long toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.intType())) {
throw new PythonException("Expected variable to be int, but was " + Python.type(pythonObject));
}
return pythonObject.toLong();
}
@Override
public Long convert(Object object) throws PythonException {
if (object instanceof Number) {
return ((Number) object).longValue();
}
throw new PythonException("Unable to cast " + object + " to Long.");
}
};
public static final PythonType<Double> FLOAT = new PythonType<Double>(TypeName.FLOAT) {
@Override
public Double toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.floatType())) {
throw new PythonException("Expected variable to be float, but was " + Python.type(pythonObject));
}
return pythonObject.toDouble();
}
@Override
public Double convert(Object object) throws PythonException {
if (object instanceof Number) {
return ((Number) object).doubleValue();
}
throw new PythonException("Unable to cast " + object + " to Double.");
}
};
public static final PythonType<Boolean> BOOL = new PythonType<Boolean>(TypeName.BOOL) {
@Override
public Boolean toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.boolType())) {
throw new PythonException("Expected variable to be bool, but was " + Python.type(pythonObject));
}
return pythonObject.toBoolean();
}
@Override
public Boolean convert(Object object) throws PythonException {
if (object instanceof Number) {
return ((Number) object).intValue() != 0;
} else if (object instanceof Boolean) {
return (Boolean) object;
}
throw new PythonException("Unable to cast " + object + " to Boolean.");
}
};
public static final PythonType<List> LIST = new PythonType<List>(TypeName.LIST) {
@Override
public List toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.listType())) {
throw new PythonException("Expected variable to be list, but was " + Python.type(pythonObject));
}
return pythonObject.toList();
}
@Override
public List convert(Object object) throws PythonException {
if (object instanceof java.util.List) {
return (List) object;
} else if (object instanceof org.json.JSONArray) {
org.json.JSONArray jsonArray = (org.json.JSONArray) object;
return jsonArray.toList();
} else if (object instanceof Object[]) {
return Arrays.asList((Object[]) object);
}
throw new PythonException("Unable to cast " + object + " to List.");
}
};
public static final PythonType<Map> DICT = new PythonType<Map>(TypeName.DICT) {
@Override
public Map toJava(PythonObject pythonObject) throws PythonException {
if (!Python.isinstance(pythonObject, Python.dictType())) {
throw new PythonException("Expected variable to be dict, but was " + Python.type(pythonObject));
}
return pythonObject.toMap();
}
@Override
public Map convert(Object object) throws PythonException {
if (object instanceof Map) {
return (Map) object;
}
throw new PythonException("Unable to cast " + object + " to Map.");
}
};
public static final PythonType<INDArray> NDARRAY = new PythonType<INDArray>(TypeName.NDARRAY) {
@Override
public INDArray toJava(PythonObject pythonObject) throws PythonException {
PythonObject np = importModule("numpy");
if (!Python.isinstance(pythonObject, np.attr("ndarray"), np.attr("generic"))) {
throw new PythonException("Expected variable to be numpy.ndarray, but was " + Python.type(pythonObject));
}
return pythonObject.toNumpy().getNd4jArray();
}
@Override
public INDArray convert(Object object) throws PythonException {
if (object instanceof INDArray) {
return (INDArray) object;
} else if (object instanceof NumpyArray) {
return ((NumpyArray) object).getNd4jArray();
}
throw new PythonException("Unable to cast " + object + " to INDArray.");
}
};
public static final PythonType<BytePointer> BYTES = new PythonType<BytePointer>(TypeName.BYTES) {
@Override
public BytePointer toJava(PythonObject pythonObject) throws PythonException {
return pythonObject.toBytePointer();
}
@Override
public BytePointer convert(Object object) throws PythonException {
if (object instanceof BytePointer) {
return (BytePointer) object;
} else if (object instanceof Pointer) {
return new BytePointer((Pointer) object);
}
throw new PythonException("Unable to cast " + object + " to BytePointer.");
}
};
}

View File

@ -24,28 +24,30 @@ public class PythonUtils {
* Create a {@link Schema} * Create a {@link Schema}
* from {@link PythonVariables}. * from {@link PythonVariables}.
* Types are mapped to types of the same name. * Types are mapped to types of the same name.
*
* @param input the input {@link PythonVariables} * @param input the input {@link PythonVariables}
* @return the output {@link Schema} * @return the output {@link Schema}
*/ */
public static Schema fromPythonVariables(PythonVariables input) { public static Schema fromPythonVariables(PythonVariables input) {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none."); Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none.");
for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) { for (String varName: input.getVariables()) {
switch(entry.getValue()) {
switch (input.getType(varName).getName()) {
case INT: case INT:
schemaBuilder.addColumnInteger(entry.getKey()); schemaBuilder.addColumnInteger(varName);
break; break;
case STR: case STR:
schemaBuilder.addColumnString(entry.getKey()); schemaBuilder.addColumnString(varName);
break; break;
case FLOAT: case FLOAT:
schemaBuilder.addColumnFloat(entry.getKey()); schemaBuilder.addColumnFloat(varName);
break; break;
case NDARRAY: case NDARRAY:
schemaBuilder.addColumnNDArray(entry.getKey(),null); schemaBuilder.addColumnNDArray(varName, null);
break; break;
case BOOL: case BOOL:
schemaBuilder.addColumn(new BooleanMetaData(entry.getKey())); schemaBuilder.addColumn(new BooleanMetaData(varName));
} }
} }
@ -56,34 +58,36 @@ public class PythonUtils {
* Create a {@link Schema} from an input * Create a {@link Schema} from an input
* {@link PythonVariables} * {@link PythonVariables}
* Types are mapped to types of the same name * Types are mapped to types of the same name
*
* @param input the input schema * @param input the input schema
* @return the output python variables. * @return the output python variables.
*/ */
public static PythonVariables fromSchema(Schema input) { public static PythonVariables fromSchema(Schema input) {
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for(int i = 0; i < input.numColumns(); i++) { for (int i = 0; i < input.numColumns(); i++) {
String currColumnName = input.getName(i); String currColumnName = input.getName(i);
ColumnType columnType = input.getType(i); ColumnType columnType = input.getType(i);
switch(columnType) { switch (columnType) {
case NDArray: case NDArray:
ret.add(currColumnName, PythonVariables.Type.NDARRAY); ret.add(currColumnName, PythonType.NDARRAY);
break; break;
case Boolean: case Boolean:
ret.add(currColumnName, PythonVariables.Type.BOOL); ret.add(currColumnName, PythonType.BOOL);
break; break;
case Categorical: case Categorical:
case String: case String:
ret.add(currColumnName, PythonVariables.Type.STR); ret.add(currColumnName, PythonType.STR);
break; break;
case Double: case Double:
case Float: case Float:
ret.add(currColumnName, PythonVariables.Type.FLOAT); ret.add(currColumnName, PythonType.FLOAT);
break; break;
case Integer: case Integer:
case Long: case Long:
ret.add(currColumnName, PythonVariables.Type.INT); ret.add(currColumnName, PythonType.INT);
break; break;
case Bytes: case Bytes:
ret.add(currColumnName, PythonType.BYTES);
break; break;
case Time: case Time:
throw new UnsupportedOperationException("Unable to process dates with python yet."); throw new UnsupportedOperationException("Unable to process dates with python yet.");
@ -92,9 +96,11 @@ public class PythonUtils {
return ret; return ret;
} }
/** /**
* Convert a {@link Schema} * Convert a {@link Schema}
* to {@link PythonVariables} * to {@link PythonVariables}
*
* @param schema the input schema * @param schema the input schema
* @return the output {@link PythonVariables} where each * @return the output {@link PythonVariables} where each
* name in the map is associated with a column name in the schema. * name in the map is associated with a column name in the schema.
@ -107,7 +113,7 @@ public class PythonUtils {
for (int i = 0; i < numCols; i++) { for (int i = 0; i < numCols; i++) {
String colName = schema.getName(i); String colName = schema.getName(i);
ColumnType colType = schema.getType(i); ColumnType colType = schema.getType(i);
switch (colType){ switch (colType) {
case Long: case Long:
case Integer: case Integer:
pyVars.addInt(colName); pyVars.addInt(colName);
@ -122,6 +128,9 @@ public class PythonUtils {
case NDArray: case NDArray:
pyVars.addNDArray(colName); pyVars.addNDArray(colName);
break; break;
case Boolean:
pyVars.addBool(colName);
break;
default: default:
throw new Exception("Unsupported python input type: " + colType.toString()); throw new Exception("Unsupported python input type: " + colType.toString());
} }
@ -131,117 +140,104 @@ public class PythonUtils {
} }
public static NumpyArray mapToNumpyArray(Map map){ public static NumpyArray mapToNumpyArray(Map map) {
String dtypeName = (String)map.get("dtype"); String dtypeName = (String) map.get("dtype");
DataType dtype; DataType dtype;
if (dtypeName.equals("float64")){ if (dtypeName.equals("float64")) {
dtype = DataType.DOUBLE; dtype = DataType.DOUBLE;
} } else if (dtypeName.equals("float32")) {
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT; dtype = DataType.FLOAT;
} } else if (dtypeName.equals("int16")) {
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT; dtype = DataType.SHORT;
} } else if (dtypeName.equals("int32")) {
else if (dtypeName.equals("int32")){
dtype = DataType.INT; dtype = DataType.INT;
} } else if (dtypeName.equals("int64")) {
else if (dtypeName.equals("int64")){
dtype = DataType.LONG; dtype = DataType.LONG;
} } else {
else{
throw new RuntimeException("Unsupported array type " + dtypeName + "."); throw new RuntimeException("Unsupported array type " + dtypeName + ".");
} }
List shapeList = (List)map.get("shape"); List shapeList = (List) map.get("shape");
long[] shape = new long[shapeList.size()]; long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) { for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i); shape[i] = (Long) shapeList.get(i);
} }
List strideList = (List)map.get("shape"); List strideList = (List) map.get("shape");
long[] stride = new long[strideList.size()]; long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) { for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i); stride[i] = (Long) strideList.get(i);
} }
long address = (Long)map.get("address"); long address = (Long) map.get("address");
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
return numpyArray; return numpyArray;
} }
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){ public static PythonVariables expandInnerDict(PythonVariables pyvars, String key) {
Map dict = pyvars.getDictValue(key); Map dict = pyvars.getDictValue(key);
String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]); String[] keys = (String[]) dict.keySet().toArray(new String[dict.keySet().size()]);
PythonVariables pyvars2 = new PythonVariables(); PythonVariables pyvars2 = new PythonVariables();
for (String subkey: keys){ for (String subkey : keys) {
Object value = dict.get(subkey); Object value = dict.get(subkey);
if (value instanceof Map){ if (value instanceof Map) {
Map map = (Map)value; Map map = (Map) value;
if (map.containsKey("_is_numpy_array")){ if (map.containsKey("_is_numpy_array")) {
pyvars2.addNDArray(subkey, mapToNumpyArray(map)); pyvars2.addNDArray(subkey, mapToNumpyArray(map));
} } else {
else{ pyvars2.addDict(subkey, (Map) value);
pyvars2.addDict(subkey, (Map)value);
} }
} } else if (value instanceof List) {
else if (value instanceof List){
pyvars2.addList(subkey, ((List) value).toArray()); pyvars2.addList(subkey, ((List) value).toArray());
} } else if (value instanceof String) {
else if (value instanceof String){ System.out.println((String) value);
System.out.println((String)value);
pyvars2.addStr(subkey, (String) value); pyvars2.addStr(subkey, (String) value);
} } else if (value instanceof Integer || value instanceof Long) {
else if (value instanceof Integer || value instanceof Long) {
Number number = (Number) value; Number number = (Number) value;
pyvars2.addInt(subkey, number.intValue()); pyvars2.addInt(subkey, number.intValue());
} } else if (value instanceof Float || value instanceof Double) {
else if (value instanceof Float || value instanceof Double) {
Number number = (Number) value; Number number = (Number) value;
pyvars2.addFloat(subkey, number.doubleValue()); pyvars2.addFloat(subkey, number.doubleValue());
} } else if (value instanceof NumpyArray) {
else if (value instanceof NumpyArray){ pyvars2.addNDArray(subkey, (NumpyArray) value);
pyvars2.addNDArray(subkey, (NumpyArray)value); } else if (value == null) {
}
else if (value == null){
pyvars2.addStr(subkey, "None"); // FixMe pyvars2.addStr(subkey, "None"); // FixMe
} } else {
else{
throw new RuntimeException("Unsupported type!" + value); throw new RuntimeException("Unsupported type!" + value);
} }
} }
return pyvars2; return pyvars2;
} }
public static long[] jsonArrayToLongArray(JSONArray jsonArray){ public static long[] jsonArrayToLongArray(JSONArray jsonArray) {
long[] longs = new long[jsonArray.length()]; long[] longs = new long[jsonArray.length()];
for (int i=0; i<longs.length; i++){ for (int i = 0; i < longs.length; i++) {
longs[i] = jsonArray.getLong(i); longs[i] = jsonArray.getLong(i);
} }
return longs; return longs;
} }
public static Map<String, Object> toMap(JSONObject jsonobj) { public static Map<String, Object> toMap(JSONObject jsonobj) {
Map<String, Object> map = new HashMap<>(); Map<String, Object> map = new HashMap<>();
String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
for (String key: keys){ for (String key : keys) {
Object value = jsonobj.get(key); Object value = jsonobj.get(key);
if (value instanceof JSONArray) { if (value instanceof JSONArray) {
value = toList((JSONArray) value); value = toList((JSONArray) value);
} else if (value instanceof JSONObject) { } else if (value instanceof JSONObject) {
JSONObject jsonobj2 = (JSONObject)value; JSONObject jsonobj2 = (JSONObject) value;
if (jsonobj2.has("_is_numpy_array")){ if (jsonobj2.has("_is_numpy_array")) {
value = jsonToNumpyArray(jsonobj2); value = jsonToNumpyArray(jsonobj2);
} } else {
else{
value = toMap(jsonobj2); value = toMap(jsonobj2);
} }
} }
map.put(key, value); map.put(key, value);
} return map; }
return map;
} }
@ -265,40 +261,35 @@ public class PythonUtils {
} }
private static NumpyArray jsonToNumpyArray(JSONObject map){ private static NumpyArray jsonToNumpyArray(JSONObject map) {
String dtypeName = (String)map.get("dtype"); String dtypeName = (String) map.get("dtype");
DataType dtype; DataType dtype;
if (dtypeName.equals("float64")){ if (dtypeName.equals("float64")) {
dtype = DataType.DOUBLE; dtype = DataType.DOUBLE;
} } else if (dtypeName.equals("float32")) {
else if (dtypeName.equals("float32")){
dtype = DataType.FLOAT; dtype = DataType.FLOAT;
} } else if (dtypeName.equals("int16")) {
else if (dtypeName.equals("int16")){
dtype = DataType.SHORT; dtype = DataType.SHORT;
} } else if (dtypeName.equals("int32")) {
else if (dtypeName.equals("int32")){
dtype = DataType.INT; dtype = DataType.INT;
} } else if (dtypeName.equals("int64")) {
else if (dtypeName.equals("int64")){
dtype = DataType.LONG; dtype = DataType.LONG;
} } else {
else{
throw new RuntimeException("Unsupported array type " + dtypeName + "."); throw new RuntimeException("Unsupported array type " + dtypeName + ".");
} }
List shapeList = (List)map.get("shape"); List shapeList = map.getJSONArray("shape").toList();
long[] shape = new long[shapeList.size()]; long[] shape = new long[shapeList.size()];
for (int i = 0; i < shape.length; i++) { for (int i = 0; i < shape.length; i++) {
shape[i] = (Long)shapeList.get(i); shape[i] = ((Number) shapeList.get(i)).longValue();
} }
List strideList = (List)map.get("shape"); List strideList = map.getJSONArray("shape").toList();
long[] stride = new long[strideList.size()]; long[] stride = new long[strideList.size()];
for (int i = 0; i < stride.length; i++) { for (int i = 0; i < stride.length; i++) {
stride[i] = (Long)strideList.get(i); stride[i] = ((Number) strideList.get(i)).longValue();
} }
long address = (Long)map.get("address"); long address = ((Number) map.get("address")).longValue();
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
return numpyArray; return numpyArray;
} }

View File

@ -17,13 +17,19 @@
package org.datavec.python; package org.datavec.python;
import lombok.Data; import lombok.Data;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.json.JSONObject; import org.json.JSONObject;
import org.json.JSONArray; import org.json.JSONArray;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.Serializable; import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.*; import java.util.*;
/** /**
* Holds python variable names, types and values. * Holds python variable names, types and values.
* Also handles mapping from java types to python types. * Also handles mapping from java types to python types.
@ -33,41 +39,31 @@ import java.util.*;
@lombok.Data @lombok.Data
public class PythonVariables implements java.io.Serializable { public class PythonVariables implements java.io.Serializable {
public enum Type{
BOOL,
STR,
INT,
FLOAT,
NDARRAY,
LIST,
FILE,
DICT
}
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>(); private java.util.Map<String, INDArray> ndVars = new java.util.LinkedHashMap<>();
private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, List> listVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, BytePointer> bytesVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>(); private java.util.Map<String, java.util.Map<?, ?>> dictVariables = new java.util.LinkedHashMap<>();
private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>(); private java.util.Map<String, PythonType.TypeName> vars = new java.util.LinkedHashMap<>();
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>(); private java.util.Map<PythonType.TypeName, java.util.Map> maps = new java.util.LinkedHashMap<>();
/** /**
* Returns a copy of the variable * Returns a copy of the variable
* schema in this array without the values * schema in this array without the values
*
* @return an empty variables clone * @return an empty variables clone
* with no values * with no values
*/ */
public PythonVariables copySchema(){ public PythonVariables copySchema() {
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for (String varName: getVariables()){ for (String varName : getVariables()) {
Type type = getType(varName); PythonType type = getType(varName);
ret.add(varName, type); ret.add(varName, type);
} }
return ret; return ret;
@ -77,21 +73,19 @@ public class PythonVariables implements java.io.Serializable {
* *
*/ */
public PythonVariables() { public PythonVariables() {
maps.put(PythonVariables.Type.BOOL, boolVariables); maps.put(PythonType.TypeName.BOOL, boolVariables);
maps.put(PythonVariables.Type.STR, strVariables); maps.put(PythonType.TypeName.STR, strVariables);
maps.put(PythonVariables.Type.INT, intVariables); maps.put(PythonType.TypeName.INT, intVariables);
maps.put(PythonVariables.Type.FLOAT, floatVariables); maps.put(PythonType.TypeName.FLOAT, floatVariables);
maps.put(PythonVariables.Type.NDARRAY, ndVars); maps.put(PythonType.TypeName.NDARRAY, ndVars);
maps.put(PythonVariables.Type.LIST, listVariables); maps.put(PythonType.TypeName.LIST, listVariables);
maps.put(PythonVariables.Type.FILE, fileVariables); maps.put(PythonType.TypeName.DICT, dictVariables);
maps.put(PythonVariables.Type.DICT, dictVariables); maps.put(PythonType.TypeName.BYTES, bytesVariables);
} }
/** /**
*
* @return true if there are no variables. * @return true if there are no variables.
*/ */
public boolean isEmpty() { public boolean isEmpty() {
@ -100,12 +94,11 @@ public class PythonVariables implements java.io.Serializable {
/** /**
*
* @param name Name of the variable * @param name Name of the variable
* @param type Type of the variable * @param type Type of the variable
*/ */
public void add(String name, Type type){ public void add(String name, PythonType type) {
switch (type){ switch (type.getName()) {
case BOOL: case BOOL:
addBool(name); addBool(name);
break; break;
@ -124,21 +117,21 @@ public class PythonVariables implements java.io.Serializable {
case LIST: case LIST:
addList(name); addList(name);
break; break;
case FILE:
addFile(name);
break;
case DICT: case DICT:
addDict(name); addDict(name);
break;
case BYTES:
addBytes(name);
break;
} }
} }
/** /**
* * @param name name of the variable
* @param name name of the variable * @param type type of the variable
* @param type type of the variable
* @param value value of the variable (must be instance of expected type) * @param value value of the variable (must be instance of expected type)
*/ */
public void add(String name, Type type, Object value) { public void add(String name, PythonType type, Object value) throws PythonException {
add(name, type); add(name, type);
setValue(name, value); setValue(name, value);
} }
@ -148,21 +141,23 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addDict(String name) { public void addDict(String name) {
vars.put(name, PythonVariables.Type.DICT); vars.put(name, PythonType.TypeName.DICT);
dictVariables.put(name,null); dictVariables.put(name, null);
} }
/** /**
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addBool(String name){ public void addBool(String name) {
vars.put(name, PythonVariables.Type.BOOL); vars.put(name, PythonType.TypeName.BOOL);
boolVariables.put(name, null); boolVariables.put(name, null);
} }
@ -170,10 +165,11 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addStr(String name){ public void addStr(String name) {
vars.put(name, PythonVariables.Type.STR); vars.put(name, PythonType.TypeName.STR);
strVariables.put(name, null); strVariables.put(name, null);
} }
@ -181,10 +177,11 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addInt(String name){ public void addInt(String name) {
vars.put(name, PythonVariables.Type.INT); vars.put(name, PythonType.TypeName.INT);
intVariables.put(name, null); intVariables.put(name, null);
} }
@ -192,10 +189,11 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addFloat(String name){ public void addFloat(String name) {
vars.put(name, PythonVariables.Type.FLOAT); vars.put(name, PythonType.TypeName.FLOAT);
floatVariables.put(name, null); floatVariables.put(name, null);
} }
@ -203,10 +201,11 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addNDArray(String name){ public void addNDArray(String name) {
vars.put(name, PythonVariables.Type.NDARRAY); vars.put(name, PythonType.TypeName.NDARRAY);
ndVars.put(name, null); ndVars.put(name, null);
} }
@ -214,99 +213,109 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
*
* @param name the field to add * @param name the field to add
*/ */
public void addList(String name){ public void addList(String name) {
vars.put(name, PythonVariables.Type.LIST); vars.put(name, PythonType.TypeName.LIST);
listVariables.put(name, null); listVariables.put(name, null);
} }
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
*/
public void addFile(String name){
vars.put(name, PythonVariables.Type.FILE);
fileVariables.put(name, null);
}
/** /**
* Add a boolean variable to * Add a boolean variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addBool(String name, boolean value) { public void addBool(String name, boolean value) {
vars.put(name, PythonVariables.Type.BOOL); vars.put(name, PythonType.TypeName.BOOL);
boolVariables.put(name, value); boolVariables.put(name, value);
} }
/** /**
* Add a string variable to * Add a string variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addStr(String name, String value) { public void addStr(String name, String value) {
vars.put(name, PythonVariables.Type.STR); vars.put(name, PythonType.TypeName.STR);
strVariables.put(name, value); strVariables.put(name, value);
} }
/** /**
* Add an int variable to * Add an int variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addInt(String name, int value) { public void addInt(String name, int value) {
vars.put(name, PythonVariables.Type.INT); vars.put(name, PythonType.TypeName.INT);
intVariables.put(name, (long)value); intVariables.put(name, (long) value);
} }
/** /**
* Add a long variable to * Add a long variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addInt(String name, long value) { public void addInt(String name, long value) {
vars.put(name, PythonVariables.Type.INT); vars.put(name, PythonType.TypeName.INT);
intVariables.put(name, value); intVariables.put(name, value);
} }
/** /**
* Add a double variable to * Add a double variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addFloat(String name, double value) { public void addFloat(String name, double value) {
vars.put(name, PythonVariables.Type.FLOAT); vars.put(name, PythonType.TypeName.FLOAT);
floatVariables.put(name, value); floatVariables.put(name, value);
} }
/** /**
* Add a float variable to * Add a float variable to
* the set of variables * the set of variables
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addFloat(String name, float value) { public void addFloat(String name, float value) {
vars.put(name, PythonVariables.Type.FLOAT); vars.put(name, PythonType.TypeName.FLOAT);
floatVariables.put(name, (double)value); floatVariables.put(name, (double) value);
} }
/** /**
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
* @param name the field to add *
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addNDArray(String name, NumpyArray value) { public void addNDArray(String name, NumpyArray value) {
vars.put(name, PythonVariables.Type.NDARRAY); vars.put(name, PythonType.TypeName.NDARRAY);
ndVars.put(name, value.getNd4jArray());
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
*
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
vars.put(name, PythonType.TypeName.NDARRAY);
ndVars.put(name, value); ndVars.put(name, value);
} }
@ -314,117 +323,63 @@ public class PythonVariables implements java.io.Serializable {
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
* @param name the field to add *
* @param value the value to add * @param name the field to add
*/
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
vars.put(name, PythonVariables.Type.NDARRAY);
ndVars.put(name, new NumpyArray(value));
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addList(String name, Object[] value) { public void addList(String name, Object[] value) {
vars.put(name, PythonVariables.Type.LIST); vars.put(name, PythonType.TypeName.LIST);
listVariables.put(name, value); listVariables.put(name, Arrays.asList(value));
} }
/** /**
* Add a null variable to * Add a null variable to
* the set of variables * the set of variables
* to describe the type but no value * to describe the type but no value
* @param name the field to add *
* @param value the value to add * @param name the field to add
*/
public void addFile(String name, String value) {
vars.put(name, PythonVariables.Type.FILE);
fileVariables.put(name, value);
}
/**
* Add a null variable to
* the set of variables
* to describe the type but no value
* @param name the field to add
* @param value the value to add * @param value the value to add
*/ */
public void addDict(String name, java.util.Map value) { public void addDict(String name, java.util.Map value) {
vars.put(name, PythonVariables.Type.DICT); vars.put(name, PythonType.TypeName.DICT);
dictVariables.put(name, value); dictVariables.put(name, value);
} }
public void addBytes(String name){
vars.put(name, PythonType.TypeName.BYTES);
bytesVariables.put(name, null);
}
public void addBytes(String name, BytePointer value){
vars.put(name, PythonType.TypeName.BYTES);
bytesVariables.put(name, value);
}
// public void addBytes(String name, ByteBuffer value){
// Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress((value.address());
// BytePointer bp = new BytePointer(ptr);
// addBytes(name, bp);
// }
/** /**
* * @param name name of the variable
* @param name name of the variable
* @param value new value for the variable * @param value new value for the variable
*/ */
public void setValue(String name, Object value) { public void setValue(String name, Object value) throws PythonException {
Type type = vars.get(name); PythonType.TypeName type = vars.get(name);
if (type == PythonVariables.Type.BOOL){ maps.get(type).put(name, PythonType.valueOf(type).convert(value));
boolVariables.put(name, (Boolean)value);
}
else if (type == PythonVariables.Type.INT){
Number number = (Number) value;
intVariables.put(name, number.longValue());
}
else if (type == PythonVariables.Type.FLOAT){
Number number = (Number) value;
floatVariables.put(name, number.doubleValue());
}
else if (type == PythonVariables.Type.NDARRAY){
if (value instanceof NumpyArray){
ndVars.put(name, (NumpyArray)value);
}
else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
}
else{
throw new RuntimeException("Unsupported type: " + value.getClass().toString());
}
}
else if (type == PythonVariables.Type.LIST) {
if (value instanceof java.util.List) {
value = ((java.util.List) value).toArray();
listVariables.put(name, (Object[]) value);
}
else if(value instanceof org.json.JSONArray) {
org.json.JSONArray jsonArray = (org.json.JSONArray) value;
Object[] copyArr = new Object[jsonArray.length()];
for(int i = 0; i < copyArr.length; i++) {
copyArr[i] = jsonArray.get(i);
}
listVariables.put(name, copyArr);
}
else {
listVariables.put(name, (Object[]) value);
}
}
else if(type == PythonVariables.Type.DICT) {
dictVariables.put(name,(java.util.Map<?,?>) value);
}
else if (type == PythonVariables.Type.FILE){
fileVariables.put(name, (String)value);
}
else{
strVariables.put(name, (String)value);
}
} }
/** /**
* Do a general object lookup. * Do a general object lookup.
* The look up will happen relative to the {@link Type} * The look up will happen relative to the {@link PythonType}
* of variable is described in the * of variable is described in the
*
* @param name the name of the variable to get * @param name the name of the variable to get
* @return teh value for the variable with the given name * @return teh value for the variable with the given name
*/ */
public Object getValue(String name) { public Object getValue(String name) {
Type type = vars.get(name); PythonType.TypeName type = vars.get(name);
java.util.Map map = maps.get(type); java.util.Map map = maps.get(type);
return map.get(name); return map.get(name);
} }
@ -432,6 +387,7 @@ public class PythonVariables implements java.io.Serializable {
/** /**
* Returns a boolean variable with the given name. * Returns a boolean variable with the given name.
*
* @param name the variable name to get the value for * @param name the variable name to get the value for
* @return the retrieved boolean value * @return the retrieved boolean value
*/ */
@ -440,80 +396,78 @@ public class PythonVariables implements java.io.Serializable {
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the dictionary value * @return the dictionary value
*/ */
public java.util.Map<?,?> getDictValue(String name) { public java.util.Map<?, ?> getDictValue(String name) {
return dictVariables.get(name); return dictVariables.get(name);
} }
/** /**
/** * /**
* *
* @param name the variable name * @param name the variable name
* @return the string value * @return the string value
*/ */
public String getStrValue(String name){ public String getStrValue(String name) {
return strVariables.get(name); return strVariables.get(name);
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the long value * @return the long value
*/ */
public Long getIntValue(String name){ public Long getIntValue(String name) {
return intVariables.get(name); return intVariables.get(name);
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the float value * @return the float value
*/ */
public Double getFloatValue(String name){ public Double getFloatValue(String name) {
return floatVariables.get(name); return floatVariables.get(name);
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the numpy array value * @return the numpy array value
*/ */
public NumpyArray getNDArrayValue(String name){ public INDArray getNDArrayValue(String name) {
return ndVars.get(name); return ndVars.get(name);
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the list value as an object array * @return the list value as an object array
*/ */
public Object[] getListValue(String name){ public List getListValue(String name) {
return listVariables.get(name); return listVariables.get(name);
} }
/** /**
*
* @param name the variable name * @param name the variable name
* @return the value of the given file name * @return the bytes value as a BytePointer
*/ */
public String getFileValue(String name){ public BytePointer getBytesValue(String name){return bytesVariables.get(name);}
return fileVariables.get(name);
}
/** /**
* Returns the type for the given variable name * Returns the type for the given variable name
*
* @param name the name of the variable to get the type for * @param name the name of the variable to get the type for
* @return the type for the given variable * @return the type for the given variable
*/ */
public Type getType(String name){ public PythonType getType(String name){
return vars.get(name); try{
return PythonType.valueOf(vars.get(name)); // will never fail
}catch (Exception e)
{
throw new RuntimeException(e);
}
} }
/** /**
* Get all the variables present as a string array * Get all the variables present as a string array
*
* @return the variable names for this variable sset * @return the variable names for this variable sset
*/ */
public String[] getVariables() { public String[] getVariables() {
@ -524,11 +478,12 @@ public class PythonVariables implements java.io.Serializable {
/** /**
* This variables set as its json representation (an array of json objects) * This variables set as its json representation (an array of json objects)
*
* @return the json array output * @return the json array output
*/ */
public org.json.JSONArray toJSON(){ public org.json.JSONArray toJSON() {
org.json.JSONArray arr = new org.json.JSONArray(); org.json.JSONArray arr = new org.json.JSONArray();
for (String varName: getVariables()){ for (String varName : getVariables()) {
org.json.JSONObject var = new org.json.JSONObject(); org.json.JSONObject var = new org.json.JSONObject();
var.put("name", varName); var.put("name", varName);
String varType = getType(varName).toString(); String varType = getType(varName).toString();
@ -542,13 +497,14 @@ public class PythonVariables implements java.io.Serializable {
* Create a schema from a map. * Create a schema from a map.
* This is an empty PythonVariables * This is an empty PythonVariables
* that just contains names and types with no values * that just contains names and types with no values
*
* @param inputTypes the input types to convert * @param inputTypes the input types to convert
* @return the schema from the given map * @return the schema from the given map
*/ */
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) { public static PythonVariables schemaFromMap(java.util.Map<String, String> inputTypes) throws Exception{
PythonVariables ret = new PythonVariables(); PythonVariables ret = new PythonVariables();
for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) { for (java.util.Map.Entry<String, String> entry : inputTypes.entrySet()) {
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue())); ret.add(entry.getKey(), PythonType.valueOf(entry.getValue()));
} }
return ret; return ret;
@ -557,39 +513,17 @@ public class PythonVariables implements java.io.Serializable {
/** /**
* Get the python variable state relative to the * Get the python variable state relative to the
* input json array * input json array
*
* @param jsonArray the input json array * @param jsonArray the input json array
* @return the python variables based on the input json array * @return the python variables based on the input json array
*/ */
public static PythonVariables fromJSON(org.json.JSONArray jsonArray){ public static PythonVariables fromJSON(org.json.JSONArray jsonArray) {
PythonVariables pyvars = new PythonVariables(); PythonVariables pyvars = new PythonVariables();
for (int i = 0; i < jsonArray.length(); i++) { for (int i = 0; i < jsonArray.length(); i++) {
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i); org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
String varName = (String)input.get("name"); String varName = (String) input.get("name");
String varType = (String)input.get("type"); String varType = (String) input.get("type");
if (varType.equals("BOOL")) { pyvars.maps.get(PythonType.TypeName.valueOf(varType)).put(varName, null);
pyvars.addBool(varName);
}
else if (varType.equals("INT")) {
pyvars.addInt(varName);
}
else if (varType.equals("FlOAT")){
pyvars.addFloat(varName);
}
else if (varType.equals("STR")) {
pyvars.addStr(varName);
}
else if (varType.equals("LIST")) {
pyvars.addList(varName);
}
else if (varType.equals("FILE")){
pyvars.addFile(varName);
}
else if (varType.equals("NDARRAY")) {
pyvars.addNDArray(varName);
}
else if(varType.equals("DICT")) {
pyvars.addDict(varName);
}
} }
return pyvars; return pyvars;

View File

@ -1,5 +0,0 @@
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
import sys
this = sys.modules[__name__]
for n in dir():
if n[0]!='_': delattr(this, n)

View File

@ -1,20 +0,0 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def maybe_serialize_ndarray_metadata(x):
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def is_json_ready(key, value):
return key is not 'f2' and not inspect.ismodule(value) \
and not hasattr(value, '__call__')

View File

@ -1,202 +0,0 @@
#patch
"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
from numpy.core._multiarray_umath import (
add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
_add_docstring = add_docstring
def add_docstring(*args):
try:
_add_docstring(*args)
except:
pass
add_docstring(
implement_array_function,
"""
Implement a function with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Arguments
---------
implementation : function
Function that implements the operation on NumPy array without
overrides when called like ``implementation(*args, **kwargs)``.
public_api : function
Function exposed by NumPy's public API originally called like
``public_api(*args, **kwargs)`` on which arguments are now being
checked.
relevant_args : iterable
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
""")
# exposed for testing purposes; used internally by implement_array_function
add_docstring(
_get_implementing_args,
"""
Collect arguments on which to call __array_function__.
Parameters
----------
relevant_args : iterable of array-like
Iterable of possibly array-like arguments to check for
__array_function__ methods.
Returns
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
""")
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
def verify_matching_signatures(implementation, dispatcher):
"""Verify that a dispatcher function has the right signature."""
implementation_spec = ArgSpec(*getargspec(implementation))
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
if (implementation_spec.args != dispatcher_spec.args or
implementation_spec.varargs != dispatcher_spec.varargs or
implementation_spec.keywords != dispatcher_spec.keywords or
(bool(implementation_spec.defaults) !=
bool(dispatcher_spec.defaults)) or
(implementation_spec.defaults is not None and
len(implementation_spec.defaults) !=
len(dispatcher_spec.defaults))):
raise RuntimeError('implementation and dispatcher for %s have '
'different function signatures' % implementation)
if implementation_spec.defaults is not None:
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
raise RuntimeError('dispatcher functions can only use None for '
'default argument values')
def set_module(module):
"""Decorator for overriding __module__ on a function or class.
Example usage::
@set_module('numpy')
def example():
pass
assert example.__module__ == 'numpy'
"""
def decorator(func):
if module is not None:
func.__module__ = module
return func
return decorator
def array_function_dispatch(dispatcher, module=None, verify=True,
docs_from_dispatcher=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
See NEP-18 for example usage.
Parameters
----------
dispatcher : callable
Function that when called like ``dispatcher(*args, **kwargs)`` with
arguments from the NumPy function call returns an iterable of
array-like arguments to check for ``__array_function__``.
module : str, optional
__module__ attribute to set on new function, e.g., ``module='numpy'``.
By default, module is copied from the decorated function.
verify : bool, optional
If True, verify the that the signature of the dispatcher and decorated
function signatures match exactly: all required and optional arguments
should appear in order with the same names, but the default values for
all optional arguments should be ``None``. Only disable verification
if the dispatcher's signature needs to deviate for some particular
reason, e.g., because the function has a signature like
``func(*args, **kwargs)``.
docs_from_dispatcher : bool, optional
If True, copy docs from the dispatcher function onto the dispatched
function, rather than from the implementation. This is useful for
functions defined in C, which otherwise don't have docstrings.
Returns
-------
Function suitable for decorating the implementation of a NumPy function.
"""
if not ENABLE_ARRAY_FUNCTION:
# __array_function__ requires an explicit opt-in for now
def decorator(implementation):
if module is not None:
implementation.__module__ = module
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
return implementation
return decorator
def decorator(implementation):
if verify:
verify_matching_signatures(implementation, dispatcher)
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None:
public_api.__module__ = module
# TODO: remove this when we drop Python 2 support (functools.wraps)
# adds __wrapped__ automatically in later versions)
public_api.__wrapped__ = implementation
return public_api
return decorator
def array_function_from_dispatcher(
implementation, module=None, verify=True, docs_from_dispatcher=True):
"""Like array_function_dispatcher, but with function arguments flipped."""
def decorator(dispatcher):
return array_function_dispatch(
dispatcher, module, verify=verify,
docs_from_dispatcher=docs_from_dispatcher)(implementation)
return decorator

View File

@ -1,172 +0,0 @@
#patch 1
"""
========================
Random Number Generation
========================
==================== =========================================================
Utility functions
==============================================================================
random_sample Uniformly distributed floats over ``[0, 1)``.
random Alias for `random_sample`.
bytes Uniformly distributed random bytes.
random_integers Uniformly distributed integers in a given range.
permutation Randomly permute a sequence / generate a random sequence.
shuffle Randomly permute a sequence in place.
seed Seed the random number generator.
choice Random sample from 1-D array.
==================== =========================================================
==================== =========================================================
Compatibility functions
==============================================================================
rand Uniformly distributed values.
randn Normally distributed values.
ranf Uniformly distributed floating point numbers.
randint Uniformly distributed integers in a given range.
==================== =========================================================
==================== =========================================================
Univariate distributions
==============================================================================
beta Beta distribution over ``[0, 1]``.
binomial Binomial distribution.
chisquare :math:`\\chi^2` distribution.
exponential Exponential distribution.
f F (Fisher-Snedecor) distribution.
gamma Gamma distribution.
geometric Geometric distribution.
gumbel Gumbel distribution.
hypergeometric Hypergeometric distribution.
laplace Laplace distribution.
logistic Logistic distribution.
lognormal Log-normal distribution.
logseries Logarithmic series distribution.
negative_binomial Negative binomial distribution.
noncentral_chisquare Non-central chi-square distribution.
noncentral_f Non-central F distribution.
normal Normal / Gaussian distribution.
pareto Pareto distribution.
poisson Poisson distribution.
power Power distribution.
rayleigh Rayleigh distribution.
triangular Triangular distribution.
uniform Uniform distribution.
vonmises Von Mises circular distribution.
wald Wald (inverse Gaussian) distribution.
weibull Weibull distribution.
zipf Zipf's distribution over ranked data.
==================== =========================================================
==================== =========================================================
Multivariate distributions
==============================================================================
dirichlet Multivariate generalization of Beta distribution.
multinomial Multivariate generalization of the binomial distribution.
multivariate_normal Multivariate generalization of the normal distribution.
==================== =========================================================
==================== =========================================================
Standard distributions
==============================================================================
standard_cauchy Standard Cauchy-Lorentz distribution.
standard_exponential Standard exponential distribution.
standard_gamma Standard Gamma distribution.
standard_normal Standard normal distribution.
standard_t Standard Student's t-distribution.
==================== =========================================================
==================== =========================================================
Internal functions
==============================================================================
get_state Get tuple representing internal state of generator.
set_state Set state of generator.
==================== =========================================================
"""
from __future__ import division, absolute_import, print_function
import warnings
__all__ = [
'beta',
'binomial',
'bytes',
'chisquare',
'choice',
'dirichlet',
'exponential',
'f',
'gamma',
'geometric',
'get_state',
'gumbel',
'hypergeometric',
'laplace',
'logistic',
'lognormal',
'logseries',
'multinomial',
'multivariate_normal',
'negative_binomial',
'noncentral_chisquare',
'noncentral_f',
'normal',
'pareto',
'permutation',
'poisson',
'power',
'rand',
'randint',
'randn',
'random_integers',
'random_sample',
'rayleigh',
'seed',
'set_state',
'shuffle',
'standard_cauchy',
'standard_exponential',
'standard_gamma',
'standard_normal',
'standard_t',
'triangular',
'uniform',
'vonmises',
'wald',
'weibull',
'zipf'
]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
try:
from .mtrand import *
# Some aliases:
ranf = random = sample = random_sample
__all__.extend(['ranf', 'random', 'sample'])
except:
warnings.warn("numpy.random is not available when using multiple interpreters!")
def __RandomState_ctor():
"""Return a RandomState instance.
This function exists solely to assist (un)pickling.
Note that the state of the RandomState returned here is irrelevant, as this function's
entire purpose is to return a newly allocated RandomState whose state pickle can set.
Consequently the RandomState returned by this function is a freshly allocated copy
with a seed=0.
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
"""
return RandomState(seed=0)
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester

View File

@ -3,13 +3,13 @@ import traceback
import json import json
import inspect import inspect
__python_exception__ = ""
try: try:
pass pass
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
except Exception as ex: except Exception as ex:
__python_exception__ = ex
try: try:
exc_info = sys.exc_info() exc_info = sys.exc_info()
finally: finally:

View File

@ -1,50 +0,0 @@
def __is_numpy_array(x):
return str(type(x))== "<class 'numpy.ndarray'>"
def __maybe_serialize_ndarray_metadata(x):
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
def __serialize_ndarray_metadata(x):
return {"address": x.__array_interface__['data'][0],
"shape": x.shape,
"strides": x.strides,
"dtype": str(x.dtype),
"_is_numpy_array": True} if __is_numpy_array(x) else x
def __serialize_list(x):
import json
return json.dumps(__recursive_serialize_list(x))
def __serialize_dict(x):
import json
return json.dumps(__recursive_serialize_dict(x))
def __recursive_serialize_list(x):
out = []
for i in x:
if __is_numpy_array(i):
out.append(__serialize_ndarray_metadata(i))
elif isinstance(i, (list, tuple)):
out.append(__recursive_serialize_list(i))
elif isinstance(i, dict):
out.append(__recursive_serialize_dict(i))
else:
out.append(i)
return out
def __recursive_serialize_dict(x):
out = {}
for k in x:
v = x[k]
if __is_numpy_array(v):
out[k] = __serialize_ndarray_metadata(v)
elif isinstance(v, (list, tuple)):
out[k] = __recursive_serialize_list(v)
elif isinstance(v, dict):
out[k] = __recursive_serialize_dict(v)
else:
out[k] = v
return out

View File

@ -0,0 +1,87 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
public class TestPythonContextManager {
@Test
public void testInt() throws Exception{
Python.setContext("context1");
Python.exec("a = 1");
Python.setContext("context2");
Python.exec("a = 2");
Python.setContext("context3");
Python.exec("a = 3");
Python.setContext("context1");
Assert.assertEquals(1, PythonExecutioner.getVariable("a").toInt());
Python.setContext("context2");
Assert.assertEquals(2, PythonExecutioner.getVariable("a").toInt());
Python.setContext("context3");
Assert.assertEquals(3, PythonExecutioner.getVariable("a").toInt());
PythonContextManager.deleteNonMainContexts();
}
@Test
public void testNDArray() throws Exception{
Python.setContext("context1");
Python.exec("import numpy as np");
Python.exec("a = np.zeros((3,2)) + 1");
Python.setContext("context2");
Python.exec("import numpy as np");
Python.exec("a = np.zeros((3,2)) + 2");
Python.setContext("context3");
Python.exec("import numpy as np");
Python.exec("a = np.zeros((3,2)) + 3");
Python.setContext("context1");
Python.exec("a += 1");
Python.setContext("context2");
Python.exec("a += 2");
Python.setContext("context3");
Python.exec("a += 3");
INDArray arr = Nd4j.create(DataType.DOUBLE, 3, 2);
Python.setContext("context1");
Assert.assertEquals(arr.add(2), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
Python.setContext("context2");
Assert.assertEquals(arr.add(4), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
Python.setContext("context3");
Assert.assertEquals(arr.add(6), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
}
}

View File

@ -0,0 +1,64 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.var;
import org.json.JSONArray;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonDict {
@Test
public void testPythonDictFromMap() throws Exception{
Map<Object, Object> map = new HashMap<>();
map.put("a", 1);
map.put("b", "a");
map.put("1", Arrays.asList(1, 2, 3, "4", Arrays.asList("x", 2.3)));
Map<Object, Object> innerMap = new HashMap<>();
innerMap.put("k", 32);
map.put("inner", innerMap);
map.put("ndarray", Nd4j.linspace(1, 4, 4));
innerMap.put("ndarray", Nd4j.linspace(5, 8, 4));
PythonObject dict = new PythonObject(map);
assertEquals(map.size(), Python.len(dict).toInt());
assertEquals("{'a': 1, '1': [1, 2, 3, '4', ['" +
"x', 2.3]], 'b': 'a', 'inner': {'k': 32," +
" 'ndarray': array([5., 6., 7., 8.], dty" +
"pe=float32)}, 'ndarray': array([1., 2., " +
"3., 4.], dtype=float32)}",
dict.toString());
Map map2 = dict.toMap();
PythonObject dict2 = new PythonObject(map2);
assertEquals(dict.toString(), dict2.toString());
}
}

View File

@ -1,75 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Assert;
import org.junit.Test;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonExecutionSandbox {
@Test
public void testInt(){
PythonExecutioner.setInterpreter("interp1");
PythonExecutioner.exec("a = 1");
PythonExecutioner.setInterpreter("interp2");
PythonExecutioner.exec("a = 2");
PythonExecutioner.setInterpreter("interp3");
PythonExecutioner.exec("a = 3");
PythonExecutioner.setInterpreter("interp1");
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp2");
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
PythonExecutioner.setInterpreter("interp3");
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
}
@Test
public void testNDArray(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 2");
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("a += 3");
PythonExecutioner.setInterpreter("main");
//PythonExecutioner.exec("import numpy as np");
// PythonExecutioner.exec("a = np.zeros(5)");
PythonExecutioner.setInterpreter("main");
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
}
@Test
public void testNumpyRandom(){
PythonExecutioner.setInterpreter("main");
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
}
}

View File

@ -15,13 +15,17 @@
******************************************************************************/ ******************************************************************************/
package org.datavec.python; package org.datavec.python;
import org.bytedeco.javacpp.BytePointer;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@javax.annotation.concurrent.NotThreadSafe @javax.annotation.concurrent.NotThreadSafe
@ -29,12 +33,12 @@ public class TestPythonExecutioner {
@org.junit.Test @org.junit.Test
public void testPythonSysVersion() { public void testPythonSysVersion() throws PythonException {
PythonExecutioner.exec("import sys; print(sys.version)"); Python.exec("import sys; print(sys.version)");
} }
@Test @Test
public void testStr() throws Exception{ public void testStr() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -46,7 +50,7 @@ public class TestPythonExecutioner {
String code = "z = x + ' ' + y"; String code = "z = x + ' ' + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
String z = pyOutputs.getStrValue("z"); String z = pyOutputs.getStrValue("z");
@ -56,7 +60,7 @@ public class TestPythonExecutioner {
} }
@Test @Test
public void testInt()throws Exception{ public void testInt() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -68,7 +72,7 @@ public class TestPythonExecutioner {
pyOutputs.addInt("z"); pyOutputs.addInt("z");
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
long z = pyOutputs.getIntValue("z"); long z = pyOutputs.getIntValue("z");
@ -77,7 +81,7 @@ public class TestPythonExecutioner {
} }
@Test @Test
public void testList() throws Exception{ public void testList() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -92,30 +96,28 @@ public class TestPythonExecutioner {
pyOutputs.addList("z"); pyOutputs.addList("z");
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
Object[] z = pyOutputs.getListValue("z"); Object[] z = pyOutputs.getListValue("z").toArray();
Assert.assertEquals(z.length, x.length + y.length); Assert.assertEquals(z.length, x.length + y.length);
for (int i = 0; i < x.length; i++) { for (int i = 0; i < x.length; i++) {
if(x[i] instanceof Number) { if (x[i] instanceof Number) {
Number xNum = (Number) x[i]; Number xNum = (Number) x[i];
Number zNum = (Number) z[i]; Number zNum = (Number) z[i];
Assert.assertEquals(xNum.intValue(), zNum.intValue()); Assert.assertEquals(xNum.intValue(), zNum.intValue());
} } else {
else {
Assert.assertEquals(x[i], z[i]); Assert.assertEquals(x[i], z[i]);
} }
} }
for (int i = 0; i < y.length; i++){ for (int i = 0; i < y.length; i++) {
if(y[i] instanceof Number) { if (y[i] instanceof Number) {
Number yNum = (Number) y[i]; Number yNum = (Number) y[i];
Number zNum = (Number) z[x.length + i]; Number zNum = (Number) z[x.length + i];
Assert.assertEquals(yNum.intValue(), zNum.intValue()); Assert.assertEquals(yNum.intValue(), zNum.intValue());
} } else {
else {
Assert.assertEquals(y[i], z[x.length + i]); Assert.assertEquals(y[i], z[x.length + i]);
} }
@ -125,7 +127,7 @@ public class TestPythonExecutioner {
} }
@Test @Test
public void testNDArrayFloat()throws Exception{ public void testNDArrayFloat() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -135,8 +137,8 @@ public class TestPythonExecutioner {
String code = "z = x + y"; String code = "z = x + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z");
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
@ -144,12 +146,13 @@ public class TestPythonExecutioner {
} }
@Test @Test
public void testTensorflowCustomAnaconda() { @Ignore
PythonExecutioner.exec("import tensorflow as tf"); public void testTensorflowCustomAnaconda() throws PythonException {
Python.exec("import tensorflow as tf");
} }
@Test @Test
public void testNDArrayDouble()throws Exception { public void testNDArrayDouble() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -159,14 +162,14 @@ public class TestPythonExecutioner {
String code = "z = x + y"; String code = "z = x + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z");
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test @Test
public void testNDArrayShort()throws Exception{ public void testNDArrayShort() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -176,15 +179,15 @@ public class TestPythonExecutioner {
String code = "z = x + y"; String code = "z = x + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z");
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test @Test
public void testNDArrayInt()throws Exception{ public void testNDArrayInt() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -194,15 +197,15 @@ public class TestPythonExecutioner {
String code = "z = x + y"; String code = "z = x + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z");
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
} }
@Test @Test
public void testNDArrayLong()throws Exception{ public void testNDArrayLong() throws Exception {
PythonVariables pyInputs = new PythonVariables(); PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables();
@ -212,12 +215,91 @@ public class TestPythonExecutioner {
String code = "z = x + y"; String code = "z = x + y";
PythonExecutioner.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); INDArray z = pyOutputs.getNDArrayValue("z");
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
}
@Test
public void testByteBufferInput() throws Exception{
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
PythonVariables pyOutputs= new PythonVariables();
pyOutputs.addStr("out");
String code = "out = buff.decode()";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
}
@Test
public void testByteBufferOutputNoCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addBytes("buff"); // same name as input, because inplace update
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
}
@Test
public void testByteBufferOutputWithCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addBytes("out");
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97\nout=bytes(buff)";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
}
@Test
public void testBadCode() throws Exception{
Python.setContext("badcode");
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
pyInputs.addNDArray("x", Nd4j.zeros(DataType.LONG, 2, 3));
pyInputs.addNDArray("y", Nd4j.ones(DataType.LONG, 2, 3));
pyOutputs.addNDArray("z");
String code = "z = x + a";
try{
Python.exec(code, pyInputs, pyOutputs);
fail("No exception thrown");
} catch (PythonException pe ){
Assert.assertEquals("NameError: name 'a' is not defined", pe.getMessage());
}
Python.setMainContext();
} }
} }

View File

@ -0,0 +1,326 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonJob {
@Test
public void testPythonJobBasic() throws Exception{
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
PythonJob job = new PythonJob("job1", code, false);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = new PythonVariables();
outputs.addInt("c");
job.exec(inputs, outputs);
assertEquals(5L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = new PythonVariables();
outputs.addFloat("c");
job.exec(inputs, outputs);
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = new PythonVariables();
outputs.addNDArray("c");
job.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
}
@Test
public void testPythonJobReturnAllVariables()throws Exception{
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
PythonJob job = new PythonJob("job1", code, false);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = job.execAndReturnAllVariables(inputs);
assertEquals(5L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
}
@Test
public void testMultiplePythonJobsParallel()throws Exception{
PythonContextManager.deleteNonMainContexts();
String code1 = "c = a + b";
PythonJob job1 = new PythonJob("job1", code1, false);
String code2 = "c = a - b";
PythonJob job2 = new PythonJob("job2", code2, false);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = new PythonVariables();
outputs.addInt("c");
job1.exec(inputs, outputs);
assertEquals(5L, (long)outputs.getIntValue("c"));
job2.exec(inputs, outputs);
assertEquals(-1L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = new PythonVariables();
outputs.addFloat("c");
job1.exec(inputs, outputs);
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
job2.exec(inputs, outputs);
assertEquals(-1L, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = new PythonVariables();
outputs.addNDArray("c");
job1.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
job2.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).sub(1), outputs.getNDArrayValue("c"));
}
@Test
public void testPythonJobSetupRun()throws Exception{
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job = new PythonJob("job1", code, true);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = new PythonVariables();
outputs.addInt("c");
job.exec(inputs, outputs);
assertEquals(10L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = new PythonVariables();
outputs.addFloat("c");
job.exec(inputs, outputs);
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = new PythonVariables();
outputs.addNDArray("c");
job.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
}
@Test
public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job = new PythonJob("job1", code, true);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = job.execAndReturnAllVariables(inputs);
assertEquals(10L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
}
@Test
public void testMultiplePythonJobsSetupRunParallel()throws Exception{
PythonContextManager.deleteNonMainContexts();
String code1 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job1 = new PythonJob("job1", code1, true);
String code2 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b - five\n"+
" return {'c':c}\n\n";
PythonJob job2 = new PythonJob("job2", code2, true);
PythonVariables inputs = new PythonVariables();
inputs.addInt("a", 2);
inputs.addInt("b", 3);
PythonVariables outputs = new PythonVariables();
outputs.addInt("c");
job1.exec(inputs, outputs);
assertEquals(10L, (long)outputs.getIntValue("c"));
job2.exec(inputs, outputs);
assertEquals(0L, (long)outputs.getIntValue("c"));
inputs = new PythonVariables();
inputs.addFloat("a", 3.0);
inputs.addFloat("b", 4.0);
outputs = new PythonVariables();
outputs.addFloat("c");
job1.exec(inputs, outputs);
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
job2.exec(inputs, outputs);
assertEquals(2L, outputs.getFloatValue("c"), 1e-5);
inputs = new PythonVariables();
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
outputs = new PythonVariables();
outputs.addNDArray("c");
job1.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
job2.exec(inputs, outputs);
assertEquals(Nd4j.zeros(3, 2).add(4), outputs.getNDArrayValue("c"));
}
}

View File

@ -0,0 +1,108 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.var;
import org.json.JSONArray;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonList {
@Test
public void testPythonListFromIntArray() {
PythonObject pyList = new PythonObject(new Integer[]{1, 2, 3, 4, 5});
pyList.attr("append").call(6);
pyList.attr("append").call(7);
pyList.attr("append").call(8);
assertEquals(8, Python.len(pyList).toInt());
for (int i = 0; i < 8; i++) {
assertEquals(i + 1, pyList.get(i).toInt());
}
}
@Test
public void testPythonListFromLongArray() {
PythonObject pyList = new PythonObject(new Long[]{1L, 2L, 3L, 4L, 5L});
pyList.attr("append").call(6);
pyList.attr("append").call(7);
pyList.attr("append").call(8);
assertEquals(8, Python.len(pyList).toInt());
for (int i = 0; i < 8; i++) {
assertEquals(i + 1, pyList.get(i).toInt());
}
}
@Test
public void testPythonListFromDoubleArray() {
PythonObject pyList = new PythonObject(new Double[]{1., 2., 3., 4., 5.});
pyList.attr("append").call(6);
pyList.attr("append").call(7);
pyList.attr("append").call(8);
assertEquals(8, Python.len(pyList).toInt());
for (int i = 0; i < 8; i++) {
assertEquals(i + 1, pyList.get(i).toInt());
assertEquals((double) i + 1, pyList.get(i).toDouble(), 1e-5);
}
}
@Test
public void testPythonListFromStringArray() {
PythonObject pyList = new PythonObject(new String[]{"abcd", "efg"});
pyList.attr("append").call("hijk");
pyList.attr("append").call("lmnop");
assertEquals("abcdefghijklmnop", new PythonObject("").attr("join").call(pyList).toString());
}
@Test
public void testPythonListFromMixedArray()throws Exception {
Map<Object, Object> map = new HashMap<>();
map.put(1, "a");
map.put("a", Arrays.asList("a", "b", "c"));
map.put("arr", Nd4j.linspace(1, 4, 4));
Object[] objs = new Object[]{
1, 2, "a", 3f, 4L, 5.0, Arrays.asList(10,
20, "b", 30f, 40L, 50.0, map
), map
};
PythonObject pyList = new PythonObject(objs);
System.out.println(pyList.toString());
String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" +
", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," +
" 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" +
"'a', 'b', 'c']}], {'arr': array([1., 2., 3.," +
" 4.], dtype=float32), 1: 'a', 'a': ['a', 'b', 'c']}]";
assertEquals(expectedStr, pyList.toString());
List objs2 = pyList.toList();
PythonObject pyList2 = new PythonObject(objs2);
assertEquals(pyList.toString(), pyList2.toString());
}
}

View File

@ -1,27 +0,0 @@
package org.datavec.python;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
public class TestPythonSetupAndRun {
@Test
public void testPythonWithSetupAndRun() throws Exception{
String code = "def setup():" +
"global counter;counter=0\n" +
"def run(step):" +
"global counter;" +
"counter+=step;" +
"return {\"counter\":counter}";
PythonVariables pyInputs = new PythonVariables();
pyInputs.addInt("step", 2);
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addInt("counter");
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
pyInputs.addInt("step", 3);
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
}
}

View File

@ -22,11 +22,14 @@
package org.datavec.python; package org.datavec.python;
import org.bytedeco.javacpp.BytePointer;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull; import static junit.framework.TestCase.assertNull;
@ -36,59 +39,50 @@ import static org.junit.Assert.assertTrue;
public class TestPythonVariables { public class TestPythonVariables {
@Test @Test
public void testImportNumpy(){ public void testDataAssociations() throws PythonException{
Nd4j.scalar(1.0);
System.out.println(System.getProperty("org.bytedeco.openblas.load"));
PythonExecutioner.exec("import numpy as np");
}
@Test
public void testDataAssociations() {
PythonVariables pythonVariables = new PythonVariables(); PythonVariables pythonVariables = new PythonVariables();
PythonVariables.Type[] types = { PythonType[] types = {
PythonVariables.Type.INT, PythonType.INT,
PythonVariables.Type.FLOAT, PythonType.FLOAT,
PythonVariables.Type.STR, PythonType.STR,
PythonVariables.Type.BOOL, PythonType.BOOL,
PythonVariables.Type.DICT, PythonType.DICT,
PythonVariables.Type.LIST, PythonType.LIST,
PythonVariables.Type.LIST, PythonType.LIST,
PythonVariables.Type.FILE, PythonType.NDARRAY,
PythonVariables.Type.NDARRAY PythonType.BYTES
}; };
NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0)); INDArray arr = Nd4j.scalar(1.0);
BytePointer bp = new BytePointer(arr.data().pointer());
Object[] values = { Object[] values = {
1L,1.0,"1",true, Collections.singletonMap("1",1), 1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, Arrays.asList(1),"type", npArr new Object[]{1}, Arrays.asList(1), arr, bp
}; };
Object[] expectedValues = { Object[] expectedValues = {
1L,1.0,"1",true, Collections.singletonMap("1",1), 1L,1.0,"1",true, Collections.singletonMap("1",1),
new Object[]{1}, new Object[]{1},"type", npArr Arrays.asList(1), Arrays.asList(1), arr, bp
}; };
for(int i = 0; i < types.length; i++) { for(int i = 0; i < types.length; i++) {
testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]); testInsertGet(pythonVariables,types[i].getName().name() + i,values[i],types[i],expectedValues[i]);
} }
assertEquals(types.length,pythonVariables.getVariables().length); assertEquals(types.length,pythonVariables.getVariables().length);
} }
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) { private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonType type,Object expectedValue) throws PythonException{
pythonVariables.add(key, type); pythonVariables.add(key, type);
assertNull(pythonVariables.getValue(key)); assertNull(pythonVariables.getValue(key));
pythonVariables.setValue(key,value); pythonVariables.setValue(key,value);
assertNotNull(pythonVariables.getValue(key)); assertNotNull(pythonVariables.getValue(key));
Object actualValue = pythonVariables.getValue(key); Object actualValue = pythonVariables.getValue(key);
if (expectedValue instanceof Object[]){ if (expectedValue instanceof Object[]){
assertTrue(actualValue instanceof Object[]); assertTrue(actualValue instanceof List);
Object[] actualArr = (Object[])actualValue; Object[] actualArr = ((List)actualValue).toArray();
Object[] expectedArr = (Object[])expectedValue; Object[] expectedArr = (Object[])expectedValue;
assertArrayEquals(expectedArr, actualArr); assertArrayEquals(expectedArr, actualArr);
} }

View File

@ -44,7 +44,7 @@ public class TestSerde {
String yaml = y.serialize(t); String yaml = y.serialize(t);
String json = j.serialize(t); String json = j.serialize(t);
Transform t2 = y.deserializeTransform(json); Transform t2 = y.deserializeTransform(yaml);
Transform t3 = j.deserializeTransform(json); Transform t3 = j.deserializeTransform(json);
assertEquals(t, t2); assertEquals(t, t2);
assertEquals(t, t3); assertEquals(t, t3);