Python executioner 2.0 (#134)
* wrapper * builtins * context mgr * direct place * ret all var * call fix * fix ndarray serde * jobs * try-with gil management * cleanup * exec tests passing * list tests * transforms test passing * all pass * headers * dict fixes+test * python path * bool isinstance * job tests * nits * transform fix+test * transform tests * leak fixes * more mem leak fixes * more fixes * nits for adam * PythonJob lombok builder * checked exceptions * more nits * small leak fix * more nits * pythonexceptions * fix jvm crash when bad python code * Exception->PythonException * Add support for boolean types in arrow records and ability to cast from float, double to int for TypeConversion (#178) * nits for alex * update tests * fix test * all pass * refacc * rem old code * dtypes * bytes working+exception pass through+cleanup (#209) * more bytes tests * header * rem dummy test * rem bad import * alex nits + refacc * Small error fixes (wrong type in msg) + minor formatting Signed-off-by: AlexDBlack <blacka101@gmail.com> * use actual python type names (dictionary->dict, boolean->bool) Co-authored-by: Shams Ul Azeem <shamsazeem20@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
7ea07de76b
commit
f6b3032def
|
@ -45,7 +45,7 @@ public class TypeConversion {
|
||||||
}
|
}
|
||||||
|
|
||||||
public int convertInt(String o) {
|
public int convertInt(String o) {
|
||||||
return Integer.parseInt(o);
|
return (int) Double.parseDouble(o);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double convertDouble(Writable writable) {
|
public double convertDouble(Writable writable) {
|
||||||
|
|
|
@ -725,7 +725,6 @@ public class ArrowConverter {
|
||||||
case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
|
case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
|
||||||
case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
|
case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
|
||||||
default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i));
|
default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -802,8 +801,13 @@ public class ArrowConverter {
|
||||||
//for proper offsets
|
//for proper offsets
|
||||||
ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
|
ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
|
||||||
nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
|
nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
|
||||||
|
case Boolean:
|
||||||
|
BitVector bitVector = (BitVector) fieldVector;
|
||||||
|
if(value instanceof Boolean)
|
||||||
|
bitVector.set(row, (boolean) value ? 1 : 0);
|
||||||
|
else
|
||||||
|
bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
}
|
}
|
||||||
}catch(Exception e) {
|
}catch(Exception e) {
|
||||||
log.warn("Unable to set value at row " + row);
|
log.warn("Unable to set value at row " + row);
|
||||||
|
|
|
@ -315,7 +315,7 @@ public class TestPythonTransformProcess {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNumpyTransform() throws Exception {
|
public void testNumpyTransform() {
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
PythonTransform pythonTransform = PythonTransform.builder()
|
||||||
.code("a += 2; b = 'hello world'")
|
.code("a += 2; b = 'hello world'")
|
||||||
.returnAllInputs(true)
|
.returnAllInputs(true)
|
||||||
|
@ -334,7 +334,42 @@ public class TestPythonTransformProcess {
|
||||||
assertFalse(execute.isEmpty());
|
assertFalse(execute.isEmpty());
|
||||||
assertNotNull(execute.get(0));
|
assertNotNull(execute.get(0));
|
||||||
assertNotNull(execute.get(0).get(0));
|
assertNotNull(execute.get(0).get(0));
|
||||||
assertEquals("hello world",execute.get(0).get(0).toString());
|
assertNotNull(execute.get(0).get(1));
|
||||||
|
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
||||||
|
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWithSetupRun() throws Exception {
|
||||||
|
|
||||||
|
PythonTransform pythonTransform = PythonTransform.builder()
|
||||||
|
.code("five=None\n" +
|
||||||
|
"def setup():\n" +
|
||||||
|
" global five\n"+
|
||||||
|
" five = 5\n\n" +
|
||||||
|
"def run(a, b):\n" +
|
||||||
|
" c = a + b + five\n"+
|
||||||
|
" return {'c':c}\n\n")
|
||||||
|
.returnAllInputs(true)
|
||||||
|
.setupAndRun(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<List<Writable>> inputs = new ArrayList<>();
|
||||||
|
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
||||||
|
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
||||||
|
Schema inputSchema = new Builder()
|
||||||
|
.addColumnNDArray("a",new long[]{1,1})
|
||||||
|
.addColumnNDArray("b", new long[]{1, 1})
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||||
|
.transform(pythonTransform)
|
||||||
|
.build();
|
||||||
|
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||||
|
assertFalse(execute.isEmpty());
|
||||||
|
assertNotNull(execute.get(0));
|
||||||
|
assertNotNull(execute.get(0).get(0));
|
||||||
|
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -29,6 +29,7 @@ import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -46,55 +47,45 @@ public class NumpyArray {
|
||||||
private long[] strides;
|
private long[] strides;
|
||||||
private DataType dtype;
|
private DataType dtype;
|
||||||
private INDArray nd4jArray;
|
private INDArray nd4jArray;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
//initialize
|
//initialize
|
||||||
Nd4j.scalar(1.0);
|
Nd4j.scalar(1.0);
|
||||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
|
public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) {
|
||||||
this.address = address;
|
this.address = address;
|
||||||
this.shape = shape;
|
this.shape = shape;
|
||||||
this.strides = strides;
|
this.strides = strides;
|
||||||
this.dtype = dtype;
|
this.dtype = dtype;
|
||||||
setND4JArray();
|
setND4JArray();
|
||||||
if (copy){
|
if (copy) {
|
||||||
nd4jArray = nd4jArray.dup();
|
nd4jArray = nd4jArray.dup();
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
this.address = nd4jArray.data().address();
|
this.address = nd4jArray.data().address();
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray copy(){
|
|
||||||
|
|
||||||
|
public NumpyArray copy() {
|
||||||
return new NumpyArray(nd4jArray.dup());
|
return new NumpyArray(nd4jArray.dup());
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray(long address, long[] shape, long strides[]){
|
public NumpyArray(long address, long[] shape, long strides[]) {
|
||||||
this(address, shape, strides, false,DataType.FLOAT);
|
this(address, shape, strides, FLOAT, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
|
public NumpyArray(long address, long[] shape, long strides[], DataType dtype) {
|
||||||
this(address, shape, strides, dtype, false);
|
this(address, shape, strides, dtype, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy){
|
|
||||||
this.address = address;
|
|
||||||
this.shape = shape;
|
|
||||||
this.strides = strides;
|
|
||||||
this.dtype = dtype;
|
|
||||||
setND4JArray();
|
|
||||||
if (copy){
|
|
||||||
nd4jArray = nd4jArray.dup();
|
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
|
||||||
this.address = nd4jArray.data().address();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setND4JArray() {
|
private void setND4JArray() {
|
||||||
long size = 1;
|
long size = 1;
|
||||||
for(long d: shape) {
|
for (long d : shape) {
|
||||||
size *= d;
|
size *= d;
|
||||||
}
|
}
|
||||||
Pointer ptr = nativeOps.pointerForAddress(address);
|
Pointer ptr = nativeOps.pointerForAddress(address);
|
||||||
|
@ -107,11 +98,11 @@ public class NumpyArray {
|
||||||
nd4jStrides[i] = strides[i] / elemSize;
|
nd4jStrides[i] = strides[i] / elemSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
|
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray(INDArray nd4jArray){
|
public NumpyArray(INDArray nd4jArray) {
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
DataBuffer buff = nd4jArray.data();
|
DataBuffer buff = nd4jArray.data();
|
||||||
address = buff.pointer().address();
|
address = buff.pointer().address();
|
||||||
|
@ -119,7 +110,7 @@ public class NumpyArray {
|
||||||
long[] nd4jStrides = nd4jArray.stride();
|
long[] nd4jStrides = nd4jArray.stride();
|
||||||
strides = new long[nd4jStrides.length];
|
strides = new long[nd4jStrides.length];
|
||||||
int elemSize = buff.getElementSize();
|
int elemSize = buff.getElementSize();
|
||||||
for(int i=0; i<strides.length; i++){
|
for (int i = 0; i < strides.length; i++) {
|
||||||
strides[i] = nd4jStrides[i] * elemSize;
|
strides[i] = nd4jStrides[i] * elemSize;
|
||||||
}
|
}
|
||||||
dtype = nd4jArray.dataType();
|
dtype = nd4jArray.dataType();
|
||||||
|
|
|
@ -0,0 +1,265 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import org.bytedeco.cpython.PyObject;
|
||||||
|
|
||||||
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Swift like python wrapper for Java
|
||||||
|
*
|
||||||
|
* @author Fariz Rahman
|
||||||
|
*/
|
||||||
|
|
||||||
|
public class Python {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Imports a python module, similar to python import statement.
|
||||||
|
* @param moduleName name of the module to be imported
|
||||||
|
* @return reference to the module object
|
||||||
|
* @throws PythonException
|
||||||
|
*/
|
||||||
|
public static PythonObject importModule(String moduleName) throws PythonException{
|
||||||
|
PythonObject module = new PythonObject(PyImport_ImportModule(moduleName));
|
||||||
|
if (module.isNone()) {
|
||||||
|
throw new PythonException("Error importing module: " + moduleName);
|
||||||
|
}
|
||||||
|
return module;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject attr(String attrName) {
|
||||||
|
return builtins().attr(attrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject len(PythonObject pythonObject) {
|
||||||
|
return attr("len").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject str(PythonObject pythonObject) {
|
||||||
|
return attr("str").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject str() {
|
||||||
|
return attr("str").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject strType() {
|
||||||
|
return attr("str");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject float_(PythonObject pythonObject) {
|
||||||
|
return attr("float").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject float_() {
|
||||||
|
return attr("float").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject floatType() {
|
||||||
|
return attr("float");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bool(PythonObject pythonObject) {
|
||||||
|
return attr("bool").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bool() {
|
||||||
|
return attr("bool").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject boolType() {
|
||||||
|
return attr("bool");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject int_(PythonObject pythonObject) {
|
||||||
|
return attr("int").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject int_() {
|
||||||
|
return attr("int").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject intType() {
|
||||||
|
return attr("int");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject list(PythonObject pythonObject) {
|
||||||
|
return attr("list").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject list() {
|
||||||
|
return attr("list").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject listType() {
|
||||||
|
return attr("list");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject dict(PythonObject pythonObject) {
|
||||||
|
return attr("dict").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject dict() {
|
||||||
|
return attr("dict").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject dictType() {
|
||||||
|
return attr("dict");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject set(PythonObject pythonObject) {
|
||||||
|
return attr("set").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject set() {
|
||||||
|
return attr("set").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytearray(PythonObject pythonObject) {
|
||||||
|
return attr("bytearray").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytearray() {
|
||||||
|
return attr("bytearray").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytearrayType() {
|
||||||
|
return attr("bytearray");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytes(PythonObject pythonObject) {
|
||||||
|
return attr("bytes").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytes() {
|
||||||
|
return attr("bytes").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject bytesType() {
|
||||||
|
return attr("bytes");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject tuple(PythonObject pythonObject) {
|
||||||
|
return attr("tuple").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject tuple() {
|
||||||
|
return attr("tuple").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static PythonObject Exception(PythonObject pythonObject) {
|
||||||
|
return attr("Exception").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject Exception() {
|
||||||
|
return attr("Exception").call();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject ExceptionType() {
|
||||||
|
return attr("Exception");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static PythonObject tupleType() {
|
||||||
|
return attr("tuple");
|
||||||
|
}
|
||||||
|
public static PythonObject globals() {
|
||||||
|
return new PythonObject(PyModule_GetDict(PyImport_ImportModule("__main__")));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject type(PythonObject obj) {
|
||||||
|
return attr("type").call(obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean isinstance(PythonObject obj, PythonObject... type) {
|
||||||
|
return PyObject_IsInstance(obj.getNativePythonObject(),
|
||||||
|
PyList_AsTuple(new PythonObject(type).getNativePythonObject())) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject eval(String code) {
|
||||||
|
PyObject compiledCode = Py_CompileString(code, "", Py_eval_input);
|
||||||
|
PyObject globals = globals().getNativePythonObject();
|
||||||
|
PyObject locals = Python.dict().getNativePythonObject();
|
||||||
|
return new PythonObject(PyEval_EvalCode(compiledCode, globals, locals));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static PythonObject builtins(){
|
||||||
|
try{
|
||||||
|
return importModule("builtins");
|
||||||
|
}catch (PythonException pe){
|
||||||
|
throw new IllegalStateException("Unable to import builtins: " + pe); // this should never happen
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject None() {
|
||||||
|
return dict().attr("get").call(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject True() {
|
||||||
|
return boolType().call(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject False() {
|
||||||
|
return boolType().call(0);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean callable(PythonObject pythonObject) {
|
||||||
|
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static void setContext(String context) throws PythonException{
|
||||||
|
PythonContextManager.setContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getCurrentContext(){
|
||||||
|
return PythonContextManager.getCurrentContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void deleteContext(String context) throws PythonException{
|
||||||
|
PythonContextManager.deleteContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void deleteNonMainContexts(){
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setMainContext(){PythonContextManager.setMainContext();}
|
||||||
|
|
||||||
|
public static void exec(String code)throws PythonException{
|
||||||
|
PythonExecutioner.exec(code);
|
||||||
|
}
|
||||||
|
public static void exec(String code, PythonVariables inputs) throws PythonException{
|
||||||
|
PythonExecutioner.exec(code, inputs);
|
||||||
|
}
|
||||||
|
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
|
||||||
|
PythonExecutioner.exec(code, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonGIL lock(){
|
||||||
|
return PythonGIL.lock();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -16,14 +16,14 @@
|
||||||
|
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
|
||||||
import org.datavec.api.transform.condition.Condition;
|
import org.datavec.api.transform.condition.Condition;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
|
import static org.datavec.python.PythonUtils.schemaToPythonVariables;
|
||||||
|
import static org.nd4j.base.Preconditions.checkNotNull;
|
||||||
|
import static org.nd4j.base.Preconditions.checkState;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Lets a condition be defined as a python method f that takes no arguments
|
* Lets a condition be defined as a python method f that takes no arguments
|
||||||
|
@ -41,18 +41,16 @@ public class PythonCondition implements Condition {
|
||||||
|
|
||||||
|
|
||||||
public PythonCondition(String pythonCode) {
|
public PythonCondition(String pythonCode) {
|
||||||
org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
|
checkNotNull("Python code must not be null!", pythonCode);
|
||||||
org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
|
checkState(!pythonCode.isEmpty(), "Python code must not be empty!");
|
||||||
code = pythonCode;
|
code = pythonCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setInputSchema(Schema inputSchema) {
|
public void setInputSchema(Schema inputSchema) {
|
||||||
this.inputSchema = inputSchema;
|
this.inputSchema = inputSchema;
|
||||||
try{
|
try {
|
||||||
pyInputs = schemaToPythonVariables(inputSchema);
|
pyInputs = schemaToPythonVariables(inputSchema);
|
||||||
PythonVariables pyOuts = new PythonVariables();
|
PythonVariables pyOuts = new PythonVariables();
|
||||||
pyOuts.addInt("out");
|
pyOuts.addInt("out");
|
||||||
|
@ -62,17 +60,15 @@ public class PythonCondition implements Condition {
|
||||||
.outputs(pyOuts)
|
.outputs(pyOuts)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
}
|
} catch (Exception e) {
|
||||||
catch (Exception e){
|
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Schema getInputSchema(){
|
public Schema getInputSchema() {
|
||||||
return inputSchema;
|
return inputSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,40 +80,39 @@ public class PythonCondition implements Condition {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String outputColumnName(){
|
public String outputColumnName() {
|
||||||
return outputColumnNames()[0];
|
return outputColumnNames()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] columnNames(){
|
public String[] columnNames() {
|
||||||
return outputColumnNames();
|
return outputColumnNames();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String columnName(){
|
public String columnName() {
|
||||||
return outputColumnName();
|
return outputColumnName();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Schema transform(Schema inputSchema){
|
public Schema transform(Schema inputSchema) {
|
||||||
return inputSchema;
|
return inputSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean condition(List<Writable> list) {
|
public boolean condition(List<Writable> list) {
|
||||||
PythonVariables inputs = getPyInputsFromWritables(list);
|
PythonVariables inputs = getPyInputsFromWritables(list);
|
||||||
try{
|
try {
|
||||||
PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
|
pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs());
|
||||||
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
|
boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
} catch (Exception e) {
|
||||||
catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean condition(Object input){
|
public boolean condition(Object input) {
|
||||||
return condition(input);
|
return condition(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,28 +130,27 @@ public class PythonCondition implements Condition {
|
||||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
|
|
||||||
for (int i = 0; i < inputSchema.numColumns(); i++){
|
for (int i = 0; i < inputSchema.numColumns(); i++) {
|
||||||
String name = inputSchema.getName(i);
|
String name = inputSchema.getName(i);
|
||||||
Writable w = writables.get(i);
|
Writable w = writables.get(i);
|
||||||
PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
|
PythonType pyType = pyInputs.getType(inputSchema.getName(i));
|
||||||
switch (pyType){
|
switch (pyType.getName()) {
|
||||||
case INT:
|
case INT:
|
||||||
if (w instanceof LongWritable) {
|
if (w instanceof LongWritable) {
|
||||||
ret.addInt(name, ((LongWritable)w).get());
|
ret.addInt(name, ((LongWritable) w).get());
|
||||||
}
|
} else {
|
||||||
else {
|
ret.addInt(name, ((IntWritable) w).get());
|
||||||
ret.addInt(name, ((IntWritable)w).get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
ret.addFloat(name, ((DoubleWritable)w).get());
|
ret.addFloat(name, ((DoubleWritable) w).get());
|
||||||
break;
|
break;
|
||||||
case STR:
|
case STR:
|
||||||
ret.addStr(name, w.toString());
|
ret.addStr(name, w.toString());
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
ret.addNDArray(name, ((NDArrayWritable) w).get());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Emulates multiples interpreters in a single interpreter.
|
||||||
|
* This works by simply obfuscating/de-obfuscating variable names
|
||||||
|
* such that only the required subset of the global namespace is "visible"
|
||||||
|
* at any given time.
|
||||||
|
* By default, there exists a "main" context emulating the default interpreter
|
||||||
|
* and cannot be deleted.
|
||||||
|
* @author Fariz Rahman
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
public class PythonContextManager {
|
||||||
|
|
||||||
|
private static Set<String> contexts = new HashSet<>();
|
||||||
|
private static AtomicBoolean init = new AtomicBoolean(false);
|
||||||
|
private static String currentContext;
|
||||||
|
private static final String MAIN_CONTEXT = "main";
|
||||||
|
static {
|
||||||
|
init();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void init() {
|
||||||
|
if (init.get()) return;
|
||||||
|
new PythonExecutioner();
|
||||||
|
init.set(true);
|
||||||
|
currentContext = MAIN_CONTEXT;
|
||||||
|
contexts.add(currentContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static void addContext(String contextName) throws PythonException {
|
||||||
|
if (!validateContextName(contextName)) {
|
||||||
|
throw new PythonException("Invalid context name: " + contextName);
|
||||||
|
}
|
||||||
|
contexts.add(contextName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean hasContext(String contextName) {
|
||||||
|
return contexts.contains(contextName);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static boolean validateContextName(String s) {
|
||||||
|
if (s.length() == 0) return false;
|
||||||
|
if (!Character.isJavaIdentifierStart(s.charAt(0))) return false;
|
||||||
|
for (int i = 1; i < s.length(); i++)
|
||||||
|
if (!Character.isJavaIdentifierPart(s.charAt(i)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getContextPrefix(String contextName) {
|
||||||
|
return "__collapsed__" + contextName + "__";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getCollapsedVarNameForContext(String varName, String contextName) {
|
||||||
|
return getContextPrefix(contextName) + varName;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String expandCollapsedVarName(String varName, String contextName) {
|
||||||
|
String prefix = "__collapsed__" + contextName + "__";
|
||||||
|
return varName.substring(prefix.length());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void collapseContext(String contextName) {
|
||||||
|
PythonObject globals = Python.globals();
|
||||||
|
PythonObject keysList = Python.list(globals.attr("keys").call());
|
||||||
|
int numKeys = Python.len(keysList).toInt();
|
||||||
|
for (int i = 0; i < numKeys; i++) {
|
||||||
|
PythonObject key = keysList.get(i);
|
||||||
|
String keyStr = key.toString();
|
||||||
|
if (!((keyStr.startsWith("__") && keyStr.endsWith("__")) || keyStr.startsWith("__collapsed_"))) {
|
||||||
|
String collapsedKey = getCollapsedVarNameForContext(keyStr, contextName);
|
||||||
|
PythonObject val = globals.attr("pop").call(key);
|
||||||
|
globals.set(new PythonObject(collapsedKey), val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void expandContext(String contextName) {
|
||||||
|
String prefix = getContextPrefix(contextName);
|
||||||
|
PythonObject globals = Python.globals();
|
||||||
|
PythonObject keysList = Python.list(globals.attr("keys").call());
|
||||||
|
int numKeys = Python.len(keysList).toInt();
|
||||||
|
for (int i = 0; i < numKeys; i++) {
|
||||||
|
PythonObject key = keysList.get(i);
|
||||||
|
String keyStr = key.toString();
|
||||||
|
if (keyStr.startsWith(prefix)) {
|
||||||
|
String expandedKey = expandCollapsedVarName(keyStr, contextName);
|
||||||
|
PythonObject val = globals.attr("pop").call(key);
|
||||||
|
globals.set(new PythonObject(expandedKey), val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setContext(String contextName) throws PythonException{
|
||||||
|
if (contextName.equals(currentContext)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!hasContext(contextName)) {
|
||||||
|
addContext(contextName);
|
||||||
|
}
|
||||||
|
collapseContext(currentContext);
|
||||||
|
expandContext(contextName);
|
||||||
|
currentContext = contextName;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setMainContext() {
|
||||||
|
try{
|
||||||
|
setContext(MAIN_CONTEXT);
|
||||||
|
}
|
||||||
|
catch (PythonException pe){
|
||||||
|
throw new RuntimeException(pe);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getCurrentContext() {
|
||||||
|
return currentContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void deleteContext(String contextName) throws PythonException {
|
||||||
|
if (contextName.equals(MAIN_CONTEXT)) {
|
||||||
|
throw new PythonException("Can not delete main context!");
|
||||||
|
}
|
||||||
|
if (contextName.equals(currentContext)) {
|
||||||
|
throw new PythonException("Can not delete current context!");
|
||||||
|
}
|
||||||
|
String prefix = getContextPrefix(contextName);
|
||||||
|
PythonObject globals = Python.globals();
|
||||||
|
PythonObject keysList = Python.list(globals.attr("keys").call());
|
||||||
|
int numKeys = Python.len(keysList).toInt();
|
||||||
|
for (int i = 0; i < numKeys; i++) {
|
||||||
|
PythonObject key = keysList.get(i);
|
||||||
|
String keyStr = key.toString();
|
||||||
|
if (keyStr.startsWith(prefix)) {
|
||||||
|
globals.attr("__delitem__").call(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contexts.remove(contextName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void deleteNonMainContexts() {
|
||||||
|
try{
|
||||||
|
setContext(MAIN_CONTEXT); // will never fail
|
||||||
|
for (String c : contexts.toArray(new String[0])) {
|
||||||
|
if (!c.equals(MAIN_CONTEXT)) {
|
||||||
|
deleteContext(c); // will never fail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}catch(Exception e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public String[] getContexts() {
|
||||||
|
return contexts.toArray(new String[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Thrown when an exception occurs in python land
|
||||||
|
*/
|
||||||
|
public class PythonException extends Exception {
|
||||||
|
public PythonException(String message){
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
private static String getExceptionString(PythonObject exception){
|
||||||
|
if (Python.isinstance(exception, Python.ExceptionType())){
|
||||||
|
String exceptionClass = Python.type(exception).attr("__name__").toString();
|
||||||
|
String message = exception.toString();
|
||||||
|
return exceptionClass + ": " + message;
|
||||||
|
}
|
||||||
|
return exception.toString();
|
||||||
|
}
|
||||||
|
public PythonException(PythonObject exception){
|
||||||
|
this(getExceptionString(exception));
|
||||||
|
}
|
||||||
|
public PythonException(String message, Throwable cause){
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
|
public PythonException(Throwable cause){
|
||||||
|
super(cause);
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,68 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.bytedeco.cpython.PyThreadState;
|
||||||
|
|
||||||
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
|
import static org.bytedeco.cpython.global.python.PyEval_RestoreThread;
|
||||||
|
import static org.bytedeco.cpython.global.python.PyEval_SaveThread;
|
||||||
|
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class PythonGIL implements AutoCloseable {
|
||||||
|
private static PyThreadState mainThreadState;
|
||||||
|
|
||||||
|
static {
|
||||||
|
log.debug("CPython: PyThreadState_Get()");
|
||||||
|
mainThreadState = PyThreadState_Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private PythonGIL() {
|
||||||
|
acquire();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonGIL lock() {
|
||||||
|
return new PythonGIL();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static synchronized void acquire() {
|
||||||
|
log.debug("acquireGIL()");
|
||||||
|
log.debug("CPython: PyEval_SaveThread()");
|
||||||
|
mainThreadState = PyEval_SaveThread();
|
||||||
|
log.debug("CPython: PyThreadState_New()");
|
||||||
|
PyThreadState ts = PyThreadState_New(mainThreadState.interp());
|
||||||
|
log.debug("CPython: PyEval_RestoreThread()");
|
||||||
|
PyEval_RestoreThread(ts);
|
||||||
|
log.debug("CPython: PyThreadState_Swap()");
|
||||||
|
PyThreadState_Swap(ts);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static synchronized void release() {
|
||||||
|
log.debug("CPython: PyEval_SaveThread()");
|
||||||
|
PyEval_SaveThread();
|
||||||
|
log.debug("CPython: PyEval_RestoreThread()");
|
||||||
|
PyEval_RestoreThread(mainThreadState);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,171 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import javax.annotation.Nonnull;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
/**
|
||||||
|
* PythonJob is the right abstraction for executing multiple python scripts
|
||||||
|
* in a multi thread stateful environment. The setup-and-run mode allows your
|
||||||
|
* "setup" code (imports, model loading etc) to be executed only once.
|
||||||
|
*/
|
||||||
|
public class PythonJob {
|
||||||
|
|
||||||
|
private String code;
|
||||||
|
private String name;
|
||||||
|
private String context;
|
||||||
|
private boolean setupRunMode;
|
||||||
|
private PythonObject runF;
|
||||||
|
|
||||||
|
static {
|
||||||
|
new PythonExecutioner();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
/**
|
||||||
|
* @param name Name for the python job.
|
||||||
|
* @param code Python code.
|
||||||
|
* @param setupRunMode If true, the python code is expected to have two methods: setup(), which takes no arguments,
|
||||||
|
* and run() which takes some or no arguments. setup() method is executed once,
|
||||||
|
* and the run() method is called with the inputs(if any) per transaction, and is expected to return a dictionary
|
||||||
|
* mapping from output variable names (str) to output values.
|
||||||
|
* If false, the full script is run on each transaction and the output variables are obtained from the global namespace
|
||||||
|
* after execution.
|
||||||
|
*/
|
||||||
|
public PythonJob(@Nonnull String name, @Nonnull String code, boolean setupRunMode) throws Exception {
|
||||||
|
this.name = name;
|
||||||
|
this.code = code;
|
||||||
|
this.setupRunMode = setupRunMode;
|
||||||
|
context = "__job_" + name;
|
||||||
|
if (PythonContextManager.hasContext(context)) {
|
||||||
|
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
|
||||||
|
}
|
||||||
|
if (setupRunMode) setup();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clears all variables in current context and calls setup()
|
||||||
|
*/
|
||||||
|
public void clearState() throws Exception {
|
||||||
|
String context = this.context;
|
||||||
|
PythonContextManager.setContext("main");
|
||||||
|
PythonContextManager.deleteContext(context);
|
||||||
|
this.context = context;
|
||||||
|
setup();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setup() throws Exception {
|
||||||
|
try (PythonGIL gil = PythonGIL.lock()) {
|
||||||
|
PythonContextManager.setContext(context);
|
||||||
|
PythonObject runF = PythonExecutioner.getVariable("run");
|
||||||
|
if (runF.isNone() || !Python.callable(runF)) {
|
||||||
|
PythonExecutioner.exec(code);
|
||||||
|
runF = PythonExecutioner.getVariable("run");
|
||||||
|
}
|
||||||
|
if (runF.isNone() || !Python.callable(runF)) {
|
||||||
|
throw new PythonException("run() method not found! " +
|
||||||
|
"If a PythonJob is created with 'setup and run' " +
|
||||||
|
"mode enabled, the associated python code is " +
|
||||||
|
"expected to contain a run() method " +
|
||||||
|
"(with or without arguments).");
|
||||||
|
}
|
||||||
|
this.runF = runF;
|
||||||
|
PythonObject setupF = PythonExecutioner.getVariable("setup");
|
||||||
|
if (!setupF.isNone()) {
|
||||||
|
setupF.call();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void exec(PythonVariables inputs, PythonVariables outputs) throws Exception {
|
||||||
|
try (PythonGIL gil = PythonGIL.lock()) {
|
||||||
|
PythonContextManager.setContext(context);
|
||||||
|
if (!setupRunMode) {
|
||||||
|
PythonExecutioner.exec(code, inputs, outputs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PythonExecutioner.setVariables(inputs);
|
||||||
|
|
||||||
|
PythonObject inspect = Python.importModule("inspect");
|
||||||
|
PythonObject getfullargspec = inspect.attr("getfullargspec");
|
||||||
|
PythonObject argspec = getfullargspec.call(runF);
|
||||||
|
PythonObject argsList = argspec.attr("args");
|
||||||
|
PythonObject runargs = Python.dict();
|
||||||
|
int argsCount = Python.len(argsList).toInt();
|
||||||
|
for (int i = 0; i < argsCount; i++) {
|
||||||
|
PythonObject arg = argsList.get(i);
|
||||||
|
PythonObject val = Python.globals().get(arg);
|
||||||
|
if (val.isNone()) {
|
||||||
|
throw new PythonException("Input value not received for run() argument: " + arg.toString());
|
||||||
|
}
|
||||||
|
runargs.set(arg, val);
|
||||||
|
}
|
||||||
|
PythonObject outDict = runF.callWithKwargs(runargs);
|
||||||
|
Python.globals().attr("update").call(outDict);
|
||||||
|
|
||||||
|
PythonExecutioner.getVariables(outputs);
|
||||||
|
inspect.del();
|
||||||
|
getfullargspec.del();
|
||||||
|
argspec.del();
|
||||||
|
runargs.del();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonVariables execAndReturnAllVariables(PythonVariables inputs) throws Exception {
|
||||||
|
try (PythonGIL gil = PythonGIL.lock()) {
|
||||||
|
PythonContextManager.setContext(context);
|
||||||
|
if (!setupRunMode) {
|
||||||
|
return PythonExecutioner.execAndReturnAllVariables(code, inputs);
|
||||||
|
}
|
||||||
|
PythonExecutioner.setVariables(inputs);
|
||||||
|
PythonObject inspect = Python.importModule("inspect");
|
||||||
|
PythonObject getfullargspec = inspect.attr("getfullargspec");
|
||||||
|
PythonObject argspec = getfullargspec.call(runF);
|
||||||
|
PythonObject argsList = argspec.attr("args");
|
||||||
|
PythonObject runargs = Python.dict();
|
||||||
|
int argsCount = Python.len(argsList).toInt();
|
||||||
|
for (int i = 0; i < argsCount; i++) {
|
||||||
|
PythonObject arg = argsList.get(i);
|
||||||
|
PythonObject val = Python.globals().get(arg);
|
||||||
|
if (val.isNone()) {
|
||||||
|
throw new PythonException("Input value not received for run() argument: " + arg.toString());
|
||||||
|
}
|
||||||
|
runargs.set(arg, val);
|
||||||
|
}
|
||||||
|
PythonObject outDict = runF.callWithKwargs(runargs);
|
||||||
|
Python.globals().attr("update").call(outDict);
|
||||||
|
inspect.del();
|
||||||
|
getfullargspec.del();
|
||||||
|
argspec.del();
|
||||||
|
runargs.del();
|
||||||
|
return PythonExecutioner.getAllVariables();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,554 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import org.bytedeco.cpython.PyObject;
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.json.JSONArray;
|
||||||
|
import org.json.JSONObject;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
|
import static org.bytedeco.cpython.global.python.PyObject_SetItem;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Swift like python wrapper for J
|
||||||
|
*
|
||||||
|
* @author Fariz Rahman
|
||||||
|
*/
|
||||||
|
|
||||||
|
public class PythonObject {
|
||||||
|
private PyObject nativePythonObject;
|
||||||
|
|
||||||
|
static {
|
||||||
|
new PythonExecutioner();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, PythonObject> _getNDArraySerializer() {
|
||||||
|
Map<String, PythonObject> ndarraySerializer = new HashMap<>();
|
||||||
|
PythonObject lambda = Python.eval(
|
||||||
|
"lambda x: " +
|
||||||
|
"{'address':" +
|
||||||
|
"x.__array_interface__['data'][0]," +
|
||||||
|
"'shape':x.shape,'strides':x.strides," +
|
||||||
|
"'dtype': str(x.dtype),'_is_numpy_array': True}" +
|
||||||
|
" if str(type(x))== \"<class 'numpy.ndarray'>\" else x");
|
||||||
|
ndarraySerializer.put("default",
|
||||||
|
lambda);
|
||||||
|
return ndarraySerializer;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(PyObject pyObject) {
|
||||||
|
nativePythonObject = pyObject;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(INDArray npArray) {
|
||||||
|
this(new NumpyArray(npArray));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(BytePointer bp){
|
||||||
|
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(NumpyArray npArray) {
|
||||||
|
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||||
|
PyObject np = PyImport_ImportModule("numpy");
|
||||||
|
PyObject ctype;
|
||||||
|
switch (npArray.getDtype()) {
|
||||||
|
case DOUBLE:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_double");
|
||||||
|
break;
|
||||||
|
case FLOAT:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_float");
|
||||||
|
break;
|
||||||
|
case LONG:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_int64");
|
||||||
|
break;
|
||||||
|
case INT:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_int32");
|
||||||
|
break;
|
||||||
|
case SHORT:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_int16");
|
||||||
|
break;
|
||||||
|
case UINT16:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_uint16");
|
||||||
|
break;
|
||||||
|
case UINT32:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_uint32");
|
||||||
|
break;
|
||||||
|
case UINT64:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_uint64");
|
||||||
|
break;
|
||||||
|
case BOOL:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_bool");
|
||||||
|
break;
|
||||||
|
case BYTE:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_byte");
|
||||||
|
break;
|
||||||
|
case UBYTE:
|
||||||
|
ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
|
||||||
|
PyObject argsTuple = PyTuple_New(1);
|
||||||
|
PyTuple_SetItem(argsTuple, 0, ctype);
|
||||||
|
PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
|
||||||
|
|
||||||
|
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
|
||||||
|
PyObject address = PyLong_FromLong(npArray.getAddress());
|
||||||
|
PyObject argsTuple2 = PyTuple_New(2);
|
||||||
|
PyTuple_SetItem(argsTuple2, 0, address);
|
||||||
|
PyTuple_SetItem(argsTuple2, 1, ptrType);
|
||||||
|
PyObject ptr = PyObject_Call(cast, argsTuple2, null);
|
||||||
|
PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
|
||||||
|
for (int i = 0; i < npArray.getShape().length; i++) {
|
||||||
|
PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
|
||||||
|
PyTuple_SetItem(shapeTuple, i, dim);
|
||||||
|
Py_DecRef(dim);
|
||||||
|
}
|
||||||
|
PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
|
||||||
|
PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
|
||||||
|
PyObject argsTuple3 = PyTuple_New(2);
|
||||||
|
PyTuple_SetItem(argsTuple3, 0, ptr);
|
||||||
|
PyTuple_SetItem(argsTuple3, 1, shapeTuple);
|
||||||
|
nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
|
||||||
|
|
||||||
|
Py_DecRef(ctypesPointer);
|
||||||
|
Py_DecRef(ctypesLib);
|
||||||
|
Py_DecRef(argsTuple);
|
||||||
|
Py_DecRef(argsTuple2);
|
||||||
|
Py_DecRef(argsTuple3);
|
||||||
|
Py_DecRef(cast);
|
||||||
|
Py_DecRef(asArray);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/*---primitve constructors---*/
|
||||||
|
public PyObject getNativePythonObject() {
|
||||||
|
return nativePythonObject;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(String data) {
|
||||||
|
nativePythonObject = PyUnicode_FromString(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(int data) {
|
||||||
|
nativePythonObject = PyLong_FromLong((long) data);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(long data) {
|
||||||
|
nativePythonObject = PyLong_FromLong(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(double data) {
|
||||||
|
nativePythonObject = PyFloat_FromDouble(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(boolean data) {
|
||||||
|
nativePythonObject = PyBool_FromLong(data ? 1 : 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static PythonObject j2pyObject(Object item) {
|
||||||
|
if (item instanceof PythonObject) {
|
||||||
|
return (PythonObject) item;
|
||||||
|
} else if (item instanceof PyObject) {
|
||||||
|
return new PythonObject((PyObject) item);
|
||||||
|
} else if (item instanceof INDArray) {
|
||||||
|
return new PythonObject((INDArray) item);
|
||||||
|
} else if (item instanceof NumpyArray) {
|
||||||
|
return new PythonObject((NumpyArray) item);
|
||||||
|
} else if (item instanceof List) {
|
||||||
|
return new PythonObject((List) item);
|
||||||
|
} else if (item instanceof Object[]) {
|
||||||
|
return new PythonObject((Object[]) item);
|
||||||
|
} else if (item instanceof Map) {
|
||||||
|
return new PythonObject((Map) item);
|
||||||
|
} else if (item instanceof String) {
|
||||||
|
return new PythonObject((String) item);
|
||||||
|
} else if (item instanceof Double) {
|
||||||
|
return new PythonObject((Double) item);
|
||||||
|
} else if (item instanceof Float) {
|
||||||
|
return new PythonObject((Float) item);
|
||||||
|
} else if (item instanceof Long) {
|
||||||
|
return new PythonObject((Long) item);
|
||||||
|
} else if (item instanceof Integer) {
|
||||||
|
return new PythonObject((Integer) item);
|
||||||
|
} else if (item instanceof Boolean) {
|
||||||
|
return new PythonObject((Boolean) item);
|
||||||
|
} else if (item instanceof Pointer){
|
||||||
|
return new PythonObject(new BytePointer((Pointer)item));
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unsupported item in list: " + item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(Object[] data) {
|
||||||
|
PyObject pyList = PyList_New((long) data.length);
|
||||||
|
for (int i = 0; i < data.length; i++) {
|
||||||
|
PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject);
|
||||||
|
}
|
||||||
|
nativePythonObject = pyList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(List data) {
|
||||||
|
PyObject pyList = PyList_New((long) data.size());
|
||||||
|
for (int i = 0; i < data.size(); i++) {
|
||||||
|
PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject);
|
||||||
|
}
|
||||||
|
nativePythonObject = pyList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject(Map data) {
|
||||||
|
PyObject pyDict = PyDict_New();
|
||||||
|
for (Object k : data.keySet()) {
|
||||||
|
PythonObject pyKey;
|
||||||
|
if (k instanceof PythonObject) {
|
||||||
|
pyKey = (PythonObject) k;
|
||||||
|
} else if (k instanceof String) {
|
||||||
|
pyKey = new PythonObject((String) k);
|
||||||
|
} else if (k instanceof Double) {
|
||||||
|
pyKey = new PythonObject((Double) k);
|
||||||
|
} else if (k instanceof Float) {
|
||||||
|
pyKey = new PythonObject((Float) k);
|
||||||
|
} else if (k instanceof Long) {
|
||||||
|
pyKey = new PythonObject((Long) k);
|
||||||
|
} else if (k instanceof Integer) {
|
||||||
|
pyKey = new PythonObject((Integer) k);
|
||||||
|
} else if (k instanceof Boolean) {
|
||||||
|
pyKey = new PythonObject((Boolean) k);
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unsupported key in map: " + k.getClass());
|
||||||
|
}
|
||||||
|
Object v = data.get(k);
|
||||||
|
PythonObject pyVal;
|
||||||
|
if (v instanceof PythonObject) {
|
||||||
|
pyVal = (PythonObject) v;
|
||||||
|
} else if (v instanceof PyObject) {
|
||||||
|
pyVal = new PythonObject((PyObject) v);
|
||||||
|
} else if (v instanceof INDArray) {
|
||||||
|
pyVal = new PythonObject((INDArray) v);
|
||||||
|
} else if (v instanceof NumpyArray) {
|
||||||
|
pyVal = new PythonObject((NumpyArray) v);
|
||||||
|
} else if (v instanceof Map) {
|
||||||
|
pyVal = new PythonObject((Map) v);
|
||||||
|
} else if (v instanceof List) {
|
||||||
|
pyVal = new PythonObject((List) v);
|
||||||
|
} else if (v instanceof String) {
|
||||||
|
pyVal = new PythonObject((String) v);
|
||||||
|
} else if (v instanceof Double) {
|
||||||
|
pyVal = new PythonObject((Double) v);
|
||||||
|
} else if (v instanceof Float) {
|
||||||
|
pyVal = new PythonObject((Float) v);
|
||||||
|
} else if (v instanceof Long) {
|
||||||
|
pyVal = new PythonObject((Long) v);
|
||||||
|
} else if (v instanceof Integer) {
|
||||||
|
pyVal = new PythonObject((Integer) v);
|
||||||
|
} else if (v instanceof Boolean) {
|
||||||
|
pyVal = new PythonObject((Boolean) v);
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unsupported value in map: " + k.getClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject);
|
||||||
|
|
||||||
|
}
|
||||||
|
nativePythonObject = pyDict;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*------*/
|
||||||
|
|
||||||
|
private static String pyObjectToString(PyObject pyObject) {
|
||||||
|
PyObject repr = PyObject_Str(pyObject);
|
||||||
|
PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~");
|
||||||
|
String jstr = PyBytes_AsString(str).getString();
|
||||||
|
Py_DecRef(repr);
|
||||||
|
Py_DecRef(str);
|
||||||
|
return jstr;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String toString() {
|
||||||
|
return pyObjectToString(nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public double toDouble() {
|
||||||
|
return PyFloat_AsDouble(nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public float toFloat() {
|
||||||
|
return (float) PyFloat_AsDouble(nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int toInt() {
|
||||||
|
return (int) PyLong_AsLong(nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public long toLong() {
|
||||||
|
return PyLong_AsLong(nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean toBoolean() {
|
||||||
|
if (isNone()) return false;
|
||||||
|
return toInt() != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NumpyArray toNumpy() {
|
||||||
|
PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
|
||||||
|
PyObject data = PyDict_GetItemString(arrInterface, "data");
|
||||||
|
PyObject pyAddress = PyTuple_GetItem(data, 0);
|
||||||
|
long address = PyLong_AsLong(pyAddress);
|
||||||
|
PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
|
||||||
|
PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
|
||||||
|
String dtypeName = pyObjectToString(pyDtypeName);
|
||||||
|
Py_DecRef(pyDtype);
|
||||||
|
Py_DecRef(pyDtypeName);
|
||||||
|
PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
|
||||||
|
PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
|
||||||
|
int ndim = (int) PyObject_Size(shape);
|
||||||
|
long[] jshape = new long[ndim];
|
||||||
|
long[] jstrides = new long[ndim];
|
||||||
|
for (int i = 0; i < ndim; i++) {
|
||||||
|
jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
|
||||||
|
jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
|
||||||
|
}
|
||||||
|
Py_DecRef(shape);
|
||||||
|
Py_DecRef(strides);
|
||||||
|
DataType dtype;
|
||||||
|
if (dtypeName.equals("float64")) {
|
||||||
|
dtype = DataType.DOUBLE;
|
||||||
|
} else if (dtypeName.equals("float32")) {
|
||||||
|
dtype = DataType.FLOAT;
|
||||||
|
} else if (dtypeName.equals("int16")) {
|
||||||
|
dtype = DataType.SHORT;
|
||||||
|
} else if (dtypeName.equals("int32")) {
|
||||||
|
dtype = DataType.INT;
|
||||||
|
} else if (dtypeName.equals("int64")) {
|
||||||
|
dtype = DataType.LONG;
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||||
|
}
|
||||||
|
return new NumpyArray(address, jshape, jstrides, dtype);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject attr(String attr) {
|
||||||
|
|
||||||
|
return new PythonObject(PyObject_GetAttrString(nativePythonObject, attr));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject call(Object... args) {
|
||||||
|
if (args.length > 0 && args[args.length - 1] instanceof Map) {
|
||||||
|
List<Object> args2 = new ArrayList<>();
|
||||||
|
for (int i = 0; i < args.length - 1; i++) {
|
||||||
|
args2.add(args[i]);
|
||||||
|
}
|
||||||
|
return call(args2, (Map) args[args.length - 1]);
|
||||||
|
}
|
||||||
|
if (args.length == 0) {
|
||||||
|
return new PythonObject(PyObject_CallObject(nativePythonObject, null));
|
||||||
|
}
|
||||||
|
PyObject tuple = PyTuple_New(args.length); // leaky; tuple may contain borrowed references, so can not be de-allocated.
|
||||||
|
for (int i = 0; i < args.length; i++) {
|
||||||
|
PyTuple_SetItem(tuple, i, j2pyObject(args[i]).nativePythonObject);
|
||||||
|
}
|
||||||
|
PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject callWithArgs(PythonObject args) {
|
||||||
|
PyObject tuple = PyList_AsTuple(args.nativePythonObject);
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject callWithKwargs(PythonObject kwargs) {
|
||||||
|
PyObject tuple = PyTuple_New(0);
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs.nativePythonObject));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) {
|
||||||
|
PyObject tuple = PyList_AsTuple(args.nativePythonObject);
|
||||||
|
PyObject dict = kwargs.nativePythonObject;
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject call(Map kwargs) {
|
||||||
|
PyObject dict = new PythonObject(kwargs).nativePythonObject;
|
||||||
|
PyObject tuple = PyTuple_New(0);
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject call(List args) {
|
||||||
|
PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject call(List args, Map kwargs) {
|
||||||
|
PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
|
||||||
|
PyObject dict = new PythonObject(kwargs).nativePythonObject;
|
||||||
|
return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
|
||||||
|
}
|
||||||
|
|
||||||
|
private PythonObject get(PyObject key) {
|
||||||
|
return new PythonObject(
|
||||||
|
PyObject_GetItem(nativePythonObject, key)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject get(PythonObject key) {
|
||||||
|
return get(key.nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public PythonObject get(int key) {
|
||||||
|
return get(PyLong_FromLong((long) key));
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject get(long key) {
|
||||||
|
return new PythonObject(
|
||||||
|
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject get(double key) {
|
||||||
|
return new PythonObject(
|
||||||
|
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PythonObject get(String key) {
|
||||||
|
return get(new PythonObject(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void set(PythonObject key, PythonObject value) {
|
||||||
|
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void del() {
|
||||||
|
Py_DecRef(nativePythonObject);
|
||||||
|
nativePythonObject = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public JSONArray toJSONArray() throws PythonException {
|
||||||
|
PythonObject json = Python.importModule("json");
|
||||||
|
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
|
||||||
|
String jsonString = serialized.toString();
|
||||||
|
return new JSONArray(jsonString);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public JSONObject toJSONObject() throws PythonException {
|
||||||
|
PythonObject json = Python.importModule("json");
|
||||||
|
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
|
||||||
|
String jsonString = serialized.toString();
|
||||||
|
return new JSONObject(jsonString);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List toList() throws PythonException{
|
||||||
|
List list = new ArrayList();
|
||||||
|
int n = Python.len(this).toInt();
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
PythonObject o = get(i);
|
||||||
|
if (Python.isinstance(o, Python.strType())) {
|
||||||
|
list.add(o.toString());
|
||||||
|
} else if (Python.isinstance(o, Python.intType())) {
|
||||||
|
list.add(o.toLong());
|
||||||
|
} else if (Python.isinstance(o, Python.floatType())) {
|
||||||
|
list.add(o.toDouble());
|
||||||
|
} else if (Python.isinstance(o, Python.boolType())) {
|
||||||
|
list.add(o);
|
||||||
|
} else if (Python.isinstance(o, Python.listType(), Python.tupleType())) {
|
||||||
|
list.add(o.toList());
|
||||||
|
} else if (Python.isinstance(o, Python.importModule("numpy").attr("ndarray"))) {
|
||||||
|
list.add(o.toNumpy().getNd4jArray());
|
||||||
|
} else if (Python.isinstance(o, Python.dictType())) {
|
||||||
|
list.add(o.toMap());
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Error while converting python" +
|
||||||
|
" list to java List: Unable to serialize python " +
|
||||||
|
"object of type " + Python.type(this).toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map toMap() throws PythonException{
|
||||||
|
Map map = new HashMap();
|
||||||
|
List keys = Python.list(attr("keys").call()).toList();
|
||||||
|
List values = Python.list(attr("values").call()).toList();
|
||||||
|
for (int i = 0; i < keys.size(); i++) {
|
||||||
|
map.put(keys.get(i), values.get(i));
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BytePointer toBytePointer() throws PythonException{
|
||||||
|
if (Python.isinstance(this, Python.bytesType())){
|
||||||
|
PyObject byteArray = PyByteArray_FromObject(nativePythonObject);
|
||||||
|
return PyByteArray_AsString(byteArray);
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (Python.isinstance(this, Python.bytearrayType())){
|
||||||
|
return PyByteArray_AsString(nativePythonObject);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||||
|
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||||
|
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
|
||||||
|
PyObject cVoidP = PyObject_GetAttrString(ctypes, "c_void_p");
|
||||||
|
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
|
||||||
|
PyObject argsTuple = PyTuple_New(2);
|
||||||
|
PyTuple_SetItem(argsTuple, 0, nativePythonObject);
|
||||||
|
PyTuple_SetItem(argsTuple, 1, cVoidP);
|
||||||
|
PyObject voidPtr = PyObject_Call(cast, argsTuple, null);
|
||||||
|
PyObject pyAddress = PyObject_GetAttrString(voidPtr, "value");
|
||||||
|
long address = PyLong_AsLong(pyAddress);
|
||||||
|
long size = PyObject_Size(nativePythonObject);
|
||||||
|
Py_DecRef(ctypes);
|
||||||
|
Py_DecRef(cArrType);
|
||||||
|
Py_DecRef(argsTuple);
|
||||||
|
Py_DecRef(voidPtr);
|
||||||
|
Py_DecRef(pyAddress);
|
||||||
|
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
|
||||||
|
ptr = ptr.limit(size);
|
||||||
|
ptr = ptr.capacity(size);
|
||||||
|
return new BytePointer(ptr);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
public boolean isNone() {
|
||||||
|
return nativePythonObject == null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -24,10 +24,13 @@ import org.datavec.api.transform.ColumnType;
|
||||||
import org.datavec.api.transform.Transform;
|
import org.datavec.api.transform.Transform;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
|
import org.json.JSONPropertyIgnore;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
@ -52,12 +55,13 @@ public class PythonTransform implements Transform {
|
||||||
private String code;
|
private String code;
|
||||||
private PythonVariables inputs;
|
private PythonVariables inputs;
|
||||||
private PythonVariables outputs;
|
private PythonVariables outputs;
|
||||||
private String name = UUID.randomUUID().toString();
|
private String name = UUID.randomUUID().toString();
|
||||||
private Schema inputSchema;
|
private Schema inputSchema;
|
||||||
private Schema outputSchema;
|
private Schema outputSchema;
|
||||||
private String outputDict;
|
private String outputDict;
|
||||||
private boolean returnAllVariables;
|
private boolean returnAllVariables;
|
||||||
private boolean setupAndRun = false;
|
private boolean setupAndRun = false;
|
||||||
|
private PythonJob pythonJob;
|
||||||
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -70,71 +74,70 @@ public class PythonTransform implements Transform {
|
||||||
String outputDict,
|
String outputDict,
|
||||||
boolean returnAllInputs,
|
boolean returnAllInputs,
|
||||||
boolean setupAndRun) {
|
boolean setupAndRun) {
|
||||||
Preconditions.checkNotNull(code,"No code found to run!");
|
Preconditions.checkNotNull(code, "No code found to run!");
|
||||||
this.code = code;
|
this.code = code;
|
||||||
this.returnAllVariables = returnAllInputs;
|
this.returnAllVariables = returnAllInputs;
|
||||||
this.setupAndRun = setupAndRun;
|
this.setupAndRun = setupAndRun;
|
||||||
if(inputs != null)
|
if (inputs != null)
|
||||||
this.inputs = inputs;
|
this.inputs = inputs;
|
||||||
if(outputs != null)
|
if (outputs != null)
|
||||||
this.outputs = outputs;
|
this.outputs = outputs;
|
||||||
|
if (name != null)
|
||||||
if(name != null)
|
|
||||||
this.name = name;
|
this.name = name;
|
||||||
if (outputDict != null) {
|
if (outputDict != null) {
|
||||||
this.outputDict = outputDict;
|
this.outputDict = outputDict;
|
||||||
this.outputs = new PythonVariables();
|
this.outputs = new PythonVariables();
|
||||||
this.outputs.addDict(outputDict);
|
this.outputs.addDict(outputDict);
|
||||||
|
|
||||||
String helpers;
|
|
||||||
try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
|
|
||||||
helpers = IOUtils.toString(is, Charset.defaultCharset());
|
|
||||||
|
|
||||||
}catch (IOException e){
|
|
||||||
throw new RuntimeException("Error reading python code");
|
|
||||||
}
|
|
||||||
this.code += "\n\n" + helpers;
|
|
||||||
this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if(inputSchema != null) {
|
if (inputSchema != null) {
|
||||||
this.inputSchema = inputSchema;
|
this.inputSchema = inputSchema;
|
||||||
if(inputs == null || inputs.isEmpty()) {
|
if (inputs == null || inputs.isEmpty()) {
|
||||||
this.inputs = schemaToPythonVariables(inputSchema);
|
this.inputs = schemaToPythonVariables(inputSchema);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(outputSchema != null) {
|
if (outputSchema != null) {
|
||||||
this.outputSchema = outputSchema;
|
this.outputSchema = outputSchema;
|
||||||
if(outputs == null || outputs.isEmpty()) {
|
if (outputs == null || outputs.isEmpty()) {
|
||||||
this.outputs = schemaToPythonVariables(outputSchema);
|
this.outputs = schemaToPythonVariables(outputSchema);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}catch(Exception e) {
|
} catch (Exception e) {
|
||||||
throw new IllegalStateException(e);
|
throw new IllegalStateException(e);
|
||||||
}
|
}
|
||||||
|
try{
|
||||||
|
pythonJob = PythonJob.builder()
|
||||||
|
.name("a" + UUID.randomUUID().toString().replace("-", "_"))
|
||||||
|
.code(code)
|
||||||
|
.setupRunMode(setupAndRun)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
catch(Exception e){
|
||||||
|
throw new IllegalStateException("Error creating python job: " + e);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setInputSchema(Schema inputSchema) {
|
public void setInputSchema(Schema inputSchema) {
|
||||||
Preconditions.checkNotNull(inputSchema,"No input schema found!");
|
Preconditions.checkNotNull(inputSchema, "No input schema found!");
|
||||||
this.inputSchema = inputSchema;
|
this.inputSchema = inputSchema;
|
||||||
try{
|
try {
|
||||||
inputs = schemaToPythonVariables(inputSchema);
|
inputs = schemaToPythonVariables(inputSchema);
|
||||||
}catch (Exception e){
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
if (outputSchema == null && outputDict == null){
|
if (outputSchema == null && outputDict == null) {
|
||||||
outputSchema = inputSchema;
|
outputSchema = inputSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Schema getInputSchema(){
|
public Schema getInputSchema() {
|
||||||
return inputSchema;
|
return inputSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,67 +161,51 @@ public class PythonTransform implements Transform {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Writable> map(List<Writable> writables) {
|
public List<Writable> map(List<Writable> writables) {
|
||||||
PythonVariables pyInputs = getPyInputsFromWritables(writables);
|
PythonVariables pyInputs = getPyInputsFromWritables(writables);
|
||||||
Preconditions.checkNotNull(pyInputs,"Inputs must not be null!");
|
Preconditions.checkNotNull(pyInputs, "Inputs must not be null!");
|
||||||
|
try {
|
||||||
|
|
||||||
try{
|
|
||||||
if (returnAllVariables) {
|
if (returnAllVariables) {
|
||||||
if (setupAndRun){
|
return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs));
|
||||||
return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
|
|
||||||
}
|
|
||||||
return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (outputDict != null) {
|
if (outputDict != null) {
|
||||||
if (setupAndRun) {
|
pythonJob.exec(pyInputs, outputs);
|
||||||
PythonExecutioner.execWithSetupAndRun(this, pyInputs);
|
|
||||||
}else{
|
|
||||||
PythonExecutioner.exec(this, pyInputs);
|
|
||||||
}
|
|
||||||
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
|
PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
|
||||||
return getWritablesFromPyOutputs(out);
|
return getWritablesFromPyOutputs(out);
|
||||||
}
|
} else {
|
||||||
else {
|
pythonJob.exec(pyInputs, outputs);
|
||||||
if (setupAndRun) {
|
|
||||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
|
|
||||||
}else{
|
|
||||||
PythonExecutioner.exec(code, pyInputs, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
return getWritablesFromPyOutputs(outputs);
|
return getWritablesFromPyOutputs(outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
} catch (Exception e) {
|
||||||
catch (Exception e){
|
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] outputColumnNames(){
|
public String[] outputColumnNames() {
|
||||||
return outputs.getVariables();
|
return outputs.getVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String outputColumnName(){
|
public String outputColumnName() {
|
||||||
return outputColumnNames()[0];
|
return outputColumnNames()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] columnNames(){
|
public String[] columnNames() {
|
||||||
return outputs.getVariables();
|
return outputs.getVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String columnName(){
|
public String columnName() {
|
||||||
return columnNames()[0];
|
return columnNames()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
public Schema transform(Schema inputSchema){
|
public Schema transform(Schema inputSchema) {
|
||||||
return outputSchema;
|
return outputSchema;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,33 +213,33 @@ public class PythonTransform implements Transform {
|
||||||
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
|
|
||||||
for (String name: inputs.getVariables()) {
|
for (String name : inputs.getVariables()) {
|
||||||
int colIdx = inputSchema.getIndexOfColumn(name);
|
int colIdx = inputSchema.getIndexOfColumn(name);
|
||||||
Writable w = writables.get(colIdx);
|
Writable w = writables.get(colIdx);
|
||||||
PythonVariables.Type pyType = inputs.getType(name);
|
PythonType pyType = inputs.getType(name);
|
||||||
switch (pyType){
|
switch (pyType.getName()) {
|
||||||
case INT:
|
case INT:
|
||||||
if (w instanceof LongWritable){
|
if (w instanceof LongWritable) {
|
||||||
ret.addInt(name, ((LongWritable)w).get());
|
ret.addInt(name, ((LongWritable) w).get());
|
||||||
|
} else {
|
||||||
|
ret.addInt(name, ((IntWritable) w).get());
|
||||||
}
|
}
|
||||||
else{
|
|
||||||
ret.addInt(name, ((IntWritable)w).get());
|
|
||||||
}
|
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
if (w instanceof DoubleWritable) {
|
if (w instanceof DoubleWritable) {
|
||||||
ret.addFloat(name, ((DoubleWritable)w).get());
|
ret.addFloat(name, ((DoubleWritable) w).get());
|
||||||
}
|
} else {
|
||||||
else{
|
ret.addFloat(name, ((FloatWritable) w).get());
|
||||||
ret.addFloat(name, ((FloatWritable)w).get());
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case STR:
|
case STR:
|
||||||
ret.addStr(name, w.toString());
|
ret.addStr(name, w.toString());
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
ret.addNDArray(name,((NDArrayWritable)w).get());
|
ret.addNDArray(name, ((NDArrayWritable) w).get());
|
||||||
|
break;
|
||||||
|
case BOOL:
|
||||||
|
ret.addBool(name, ((BooleanWritable) w).get());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Unsupported input type:" + pyType);
|
throw new RuntimeException("Unsupported input type:" + pyType);
|
||||||
|
@ -269,8 +256,8 @@ public class PythonTransform implements Transform {
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
for (int i = 0; i < varNames.length; i++) {
|
for (int i = 0; i < varNames.length; i++) {
|
||||||
String name = varNames[i];
|
String name = varNames[i];
|
||||||
PythonVariables.Type pyType = pyOuts.getType(name);
|
PythonType pyType = pyOuts.getType(name);
|
||||||
switch (pyType){
|
switch (pyType.getName()) {
|
||||||
case INT:
|
case INT:
|
||||||
schemaBuilder.addColumnLong(name);
|
schemaBuilder.addColumnLong(name);
|
||||||
break;
|
break;
|
||||||
|
@ -283,11 +270,14 @@ public class PythonTransform implements Transform {
|
||||||
schemaBuilder.addColumnString(name);
|
schemaBuilder.addColumnString(name);
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
INDArray arr = pyOuts.getNDArrayValue(name);
|
||||||
schemaBuilder.addColumnNDArray(name, arr.getShape());
|
schemaBuilder.addColumnNDArray(name, arr.shape());
|
||||||
|
break;
|
||||||
|
case BOOL:
|
||||||
|
schemaBuilder.addColumnBoolean(name);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new IllegalStateException("Unable to support type " + pyType.name());
|
throw new IllegalStateException("Unable to support type " + pyType.getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.outputSchema = schemaBuilder.build();
|
this.outputSchema = schemaBuilder.build();
|
||||||
|
@ -295,9 +285,9 @@ public class PythonTransform implements Transform {
|
||||||
|
|
||||||
for (int i = 0; i < varNames.length; i++) {
|
for (int i = 0; i < varNames.length; i++) {
|
||||||
String name = varNames[i];
|
String name = varNames[i];
|
||||||
PythonVariables.Type pyType = pyOuts.getType(name);
|
PythonType pyType = pyOuts.getType(name);
|
||||||
|
|
||||||
switch (pyType){
|
switch (pyType.getName()) {
|
||||||
case INT:
|
case INT:
|
||||||
out.add(new LongWritable(pyOuts.getIntValue(name)));
|
out.add(new LongWritable(pyOuts.getIntValue(name)));
|
||||||
break;
|
break;
|
||||||
|
@ -308,14 +298,14 @@ public class PythonTransform implements Transform {
|
||||||
out.add(new Text(pyOuts.getStrValue(name)));
|
out.add(new Text(pyOuts.getStrValue(name)));
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
NumpyArray arr = pyOuts.getNDArrayValue(name);
|
INDArray arr = pyOuts.getNDArrayValue(name);
|
||||||
out.add(new NDArrayWritable(arr.getNd4jArray()));
|
out.add(new NDArrayWritable(arr));
|
||||||
break;
|
break;
|
||||||
case DICT:
|
case DICT:
|
||||||
Map<?, ?> dictValue = pyOuts.getDictValue(name);
|
Map<?, ?> dictValue = pyOuts.getDictValue(name);
|
||||||
Map noNullValues = new java.util.HashMap<>();
|
Map noNullValues = new java.util.HashMap<>();
|
||||||
for(Map.Entry entry : dictValue.entrySet()) {
|
for (Map.Entry entry : dictValue.entrySet()) {
|
||||||
if(entry.getValue() != org.json.JSONObject.NULL) {
|
if (entry.getValue() != org.json.JSONObject.NULL) {
|
||||||
noNullValues.put(entry.getKey(), entry.getValue());
|
noNullValues.put(entry.getKey(), entry.getValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -327,21 +317,22 @@ public class PythonTransform implements Transform {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case LIST:
|
case LIST:
|
||||||
Object[] listValue = pyOuts.getListValue(name);
|
Object[] listValue = pyOuts.getListValue(name).toArray();
|
||||||
try {
|
try {
|
||||||
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
|
out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
|
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case BOOL:
|
||||||
|
out.add(new BooleanWritable(pyOuts.getBooleanValue(name)));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new IllegalStateException("Unable to support type " + pyType.name());
|
throw new IllegalStateException("Unable to support type " + pyType.getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,238 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.datavec.python.Python.importModule;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param <T> Corresponding Java type for the Python type
|
||||||
|
*/
|
||||||
|
public abstract class PythonType<T> {
|
||||||
|
|
||||||
|
public abstract T toJava(PythonObject pythonObject) throws PythonException;
|
||||||
|
private final TypeName typeName;
|
||||||
|
|
||||||
|
enum TypeName{
|
||||||
|
STR,
|
||||||
|
INT,
|
||||||
|
FLOAT,
|
||||||
|
BOOL,
|
||||||
|
LIST,
|
||||||
|
DICT,
|
||||||
|
NDARRAY,
|
||||||
|
BYTES
|
||||||
|
}
|
||||||
|
private PythonType(TypeName typeName){
|
||||||
|
this.typeName = typeName;
|
||||||
|
}
|
||||||
|
public TypeName getName(){return typeName;}
|
||||||
|
public String toString(){
|
||||||
|
return getName().name();
|
||||||
|
}
|
||||||
|
public static PythonType valueOf(String typeName) throws PythonException{
|
||||||
|
try{
|
||||||
|
typeName.valueOf(typeName);
|
||||||
|
} catch (IllegalArgumentException iae){
|
||||||
|
throw new PythonException("Invalid python type: " + typeName, iae);
|
||||||
|
}
|
||||||
|
try{
|
||||||
|
return (PythonType)PythonType.class.getField(typeName).get(null); // shouldn't fail
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
public static PythonType valueOf(TypeName typeName){
|
||||||
|
try{
|
||||||
|
return valueOf(typeName.name()); // shouldn't fail
|
||||||
|
}catch (PythonException pe){
|
||||||
|
throw new RuntimeException(pe);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Since multiple java types can map to the same python type,
|
||||||
|
* this method "normalizes" all supported incoming objects to T
|
||||||
|
*
|
||||||
|
* @param object object to be converted to type T
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public T convert(Object object) throws PythonException {
|
||||||
|
return (T) object;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final PythonType<String> STR = new PythonType<String>(TypeName.STR) {
|
||||||
|
@Override
|
||||||
|
public String toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.strType())) {
|
||||||
|
throw new PythonException("Expected variable to be str, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String convert(Object object) {
|
||||||
|
return object.toString();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<Long> INT = new PythonType<Long>(TypeName.INT) {
|
||||||
|
@Override
|
||||||
|
public Long toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.intType())) {
|
||||||
|
throw new PythonException("Expected variable to be int, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toLong();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Long convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof Number) {
|
||||||
|
return ((Number) object).longValue();
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to Long.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<Double> FLOAT = new PythonType<Double>(TypeName.FLOAT) {
|
||||||
|
@Override
|
||||||
|
public Double toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.floatType())) {
|
||||||
|
throw new PythonException("Expected variable to be float, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Double convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof Number) {
|
||||||
|
return ((Number) object).doubleValue();
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to Double.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<Boolean> BOOL = new PythonType<Boolean>(TypeName.BOOL) {
|
||||||
|
@Override
|
||||||
|
public Boolean toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.boolType())) {
|
||||||
|
throw new PythonException("Expected variable to be bool, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toBoolean();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Boolean convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof Number) {
|
||||||
|
return ((Number) object).intValue() != 0;
|
||||||
|
} else if (object instanceof Boolean) {
|
||||||
|
return (Boolean) object;
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to Boolean.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<List> LIST = new PythonType<List>(TypeName.LIST) {
|
||||||
|
@Override
|
||||||
|
public List toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.listType())) {
|
||||||
|
throw new PythonException("Expected variable to be list, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof java.util.List) {
|
||||||
|
return (List) object;
|
||||||
|
} else if (object instanceof org.json.JSONArray) {
|
||||||
|
org.json.JSONArray jsonArray = (org.json.JSONArray) object;
|
||||||
|
return jsonArray.toList();
|
||||||
|
|
||||||
|
} else if (object instanceof Object[]) {
|
||||||
|
return Arrays.asList((Object[]) object);
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to List.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<Map> DICT = new PythonType<Map>(TypeName.DICT) {
|
||||||
|
@Override
|
||||||
|
public Map toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
if (!Python.isinstance(pythonObject, Python.dictType())) {
|
||||||
|
throw new PythonException("Expected variable to be dict, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof Map) {
|
||||||
|
return (Map) object;
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to Map.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<INDArray> NDARRAY = new PythonType<INDArray>(TypeName.NDARRAY) {
|
||||||
|
@Override
|
||||||
|
public INDArray toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
PythonObject np = importModule("numpy");
|
||||||
|
if (!Python.isinstance(pythonObject, np.attr("ndarray"), np.attr("generic"))) {
|
||||||
|
throw new PythonException("Expected variable to be numpy.ndarray, but was " + Python.type(pythonObject));
|
||||||
|
}
|
||||||
|
return pythonObject.toNumpy().getNd4jArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof INDArray) {
|
||||||
|
return (INDArray) object;
|
||||||
|
} else if (object instanceof NumpyArray) {
|
||||||
|
return ((NumpyArray) object).getNd4jArray();
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to INDArray.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
public static final PythonType<BytePointer> BYTES = new PythonType<BytePointer>(TypeName.BYTES) {
|
||||||
|
@Override
|
||||||
|
public BytePointer toJava(PythonObject pythonObject) throws PythonException {
|
||||||
|
return pythonObject.toBytePointer();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytePointer convert(Object object) throws PythonException {
|
||||||
|
if (object instanceof BytePointer) {
|
||||||
|
return (BytePointer) object;
|
||||||
|
} else if (object instanceof Pointer) {
|
||||||
|
return new BytePointer((Pointer) object);
|
||||||
|
}
|
||||||
|
throw new PythonException("Unable to cast " + object + " to BytePointer.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
|
@ -24,28 +24,30 @@ public class PythonUtils {
|
||||||
* Create a {@link Schema}
|
* Create a {@link Schema}
|
||||||
* from {@link PythonVariables}.
|
* from {@link PythonVariables}.
|
||||||
* Types are mapped to types of the same name.
|
* Types are mapped to types of the same name.
|
||||||
|
*
|
||||||
* @param input the input {@link PythonVariables}
|
* @param input the input {@link PythonVariables}
|
||||||
* @return the output {@link Schema}
|
* @return the output {@link Schema}
|
||||||
*/
|
*/
|
||||||
public static Schema fromPythonVariables(PythonVariables input) {
|
public static Schema fromPythonVariables(PythonVariables input) {
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none.");
|
Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none.");
|
||||||
for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) {
|
for (String varName: input.getVariables()) {
|
||||||
switch(entry.getValue()) {
|
|
||||||
|
switch (input.getType(varName).getName()) {
|
||||||
case INT:
|
case INT:
|
||||||
schemaBuilder.addColumnInteger(entry.getKey());
|
schemaBuilder.addColumnInteger(varName);
|
||||||
break;
|
break;
|
||||||
case STR:
|
case STR:
|
||||||
schemaBuilder.addColumnString(entry.getKey());
|
schemaBuilder.addColumnString(varName);
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
schemaBuilder.addColumnFloat(entry.getKey());
|
schemaBuilder.addColumnFloat(varName);
|
||||||
break;
|
break;
|
||||||
case NDARRAY:
|
case NDARRAY:
|
||||||
schemaBuilder.addColumnNDArray(entry.getKey(),null);
|
schemaBuilder.addColumnNDArray(varName, null);
|
||||||
break;
|
break;
|
||||||
case BOOL:
|
case BOOL:
|
||||||
schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
|
schemaBuilder.addColumn(new BooleanMetaData(varName));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,34 +58,36 @@ public class PythonUtils {
|
||||||
* Create a {@link Schema} from an input
|
* Create a {@link Schema} from an input
|
||||||
* {@link PythonVariables}
|
* {@link PythonVariables}
|
||||||
* Types are mapped to types of the same name
|
* Types are mapped to types of the same name
|
||||||
|
*
|
||||||
* @param input the input schema
|
* @param input the input schema
|
||||||
* @return the output python variables.
|
* @return the output python variables.
|
||||||
*/
|
*/
|
||||||
public static PythonVariables fromSchema(Schema input) {
|
public static PythonVariables fromSchema(Schema input) {
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
for(int i = 0; i < input.numColumns(); i++) {
|
for (int i = 0; i < input.numColumns(); i++) {
|
||||||
String currColumnName = input.getName(i);
|
String currColumnName = input.getName(i);
|
||||||
ColumnType columnType = input.getType(i);
|
ColumnType columnType = input.getType(i);
|
||||||
switch(columnType) {
|
switch (columnType) {
|
||||||
case NDArray:
|
case NDArray:
|
||||||
ret.add(currColumnName, PythonVariables.Type.NDARRAY);
|
ret.add(currColumnName, PythonType.NDARRAY);
|
||||||
break;
|
break;
|
||||||
case Boolean:
|
case Boolean:
|
||||||
ret.add(currColumnName, PythonVariables.Type.BOOL);
|
ret.add(currColumnName, PythonType.BOOL);
|
||||||
break;
|
break;
|
||||||
case Categorical:
|
case Categorical:
|
||||||
case String:
|
case String:
|
||||||
ret.add(currColumnName, PythonVariables.Type.STR);
|
ret.add(currColumnName, PythonType.STR);
|
||||||
break;
|
break;
|
||||||
case Double:
|
case Double:
|
||||||
case Float:
|
case Float:
|
||||||
ret.add(currColumnName, PythonVariables.Type.FLOAT);
|
ret.add(currColumnName, PythonType.FLOAT);
|
||||||
break;
|
break;
|
||||||
case Integer:
|
case Integer:
|
||||||
case Long:
|
case Long:
|
||||||
ret.add(currColumnName, PythonVariables.Type.INT);
|
ret.add(currColumnName, PythonType.INT);
|
||||||
break;
|
break;
|
||||||
case Bytes:
|
case Bytes:
|
||||||
|
ret.add(currColumnName, PythonType.BYTES);
|
||||||
break;
|
break;
|
||||||
case Time:
|
case Time:
|
||||||
throw new UnsupportedOperationException("Unable to process dates with python yet.");
|
throw new UnsupportedOperationException("Unable to process dates with python yet.");
|
||||||
|
@ -92,9 +96,11 @@ public class PythonUtils {
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert a {@link Schema}
|
* Convert a {@link Schema}
|
||||||
* to {@link PythonVariables}
|
* to {@link PythonVariables}
|
||||||
|
*
|
||||||
* @param schema the input schema
|
* @param schema the input schema
|
||||||
* @return the output {@link PythonVariables} where each
|
* @return the output {@link PythonVariables} where each
|
||||||
* name in the map is associated with a column name in the schema.
|
* name in the map is associated with a column name in the schema.
|
||||||
|
@ -107,7 +113,7 @@ public class PythonUtils {
|
||||||
for (int i = 0; i < numCols; i++) {
|
for (int i = 0; i < numCols; i++) {
|
||||||
String colName = schema.getName(i);
|
String colName = schema.getName(i);
|
||||||
ColumnType colType = schema.getType(i);
|
ColumnType colType = schema.getType(i);
|
||||||
switch (colType){
|
switch (colType) {
|
||||||
case Long:
|
case Long:
|
||||||
case Integer:
|
case Integer:
|
||||||
pyVars.addInt(colName);
|
pyVars.addInt(colName);
|
||||||
|
@ -122,6 +128,9 @@ public class PythonUtils {
|
||||||
case NDArray:
|
case NDArray:
|
||||||
pyVars.addNDArray(colName);
|
pyVars.addNDArray(colName);
|
||||||
break;
|
break;
|
||||||
|
case Boolean:
|
||||||
|
pyVars.addBool(colName);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new Exception("Unsupported python input type: " + colType.toString());
|
throw new Exception("Unsupported python input type: " + colType.toString());
|
||||||
}
|
}
|
||||||
|
@ -131,117 +140,104 @@ public class PythonUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static NumpyArray mapToNumpyArray(Map map){
|
public static NumpyArray mapToNumpyArray(Map map) {
|
||||||
String dtypeName = (String)map.get("dtype");
|
String dtypeName = (String) map.get("dtype");
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
if (dtypeName.equals("float64")){
|
if (dtypeName.equals("float64")) {
|
||||||
dtype = DataType.DOUBLE;
|
dtype = DataType.DOUBLE;
|
||||||
}
|
} else if (dtypeName.equals("float32")) {
|
||||||
else if (dtypeName.equals("float32")){
|
|
||||||
dtype = DataType.FLOAT;
|
dtype = DataType.FLOAT;
|
||||||
}
|
} else if (dtypeName.equals("int16")) {
|
||||||
else if (dtypeName.equals("int16")){
|
|
||||||
dtype = DataType.SHORT;
|
dtype = DataType.SHORT;
|
||||||
}
|
} else if (dtypeName.equals("int32")) {
|
||||||
else if (dtypeName.equals("int32")){
|
|
||||||
dtype = DataType.INT;
|
dtype = DataType.INT;
|
||||||
}
|
} else if (dtypeName.equals("int64")) {
|
||||||
else if (dtypeName.equals("int64")){
|
|
||||||
dtype = DataType.LONG;
|
dtype = DataType.LONG;
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||||
}
|
}
|
||||||
List shapeList = (List)map.get("shape");
|
List shapeList = (List) map.get("shape");
|
||||||
long[] shape = new long[shapeList.size()];
|
long[] shape = new long[shapeList.size()];
|
||||||
for (int i = 0; i < shape.length; i++) {
|
for (int i = 0; i < shape.length; i++) {
|
||||||
shape[i] = (Long)shapeList.get(i);
|
shape[i] = (Long) shapeList.get(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
List strideList = (List)map.get("shape");
|
List strideList = (List) map.get("shape");
|
||||||
long[] stride = new long[strideList.size()];
|
long[] stride = new long[strideList.size()];
|
||||||
for (int i = 0; i < stride.length; i++) {
|
for (int i = 0; i < stride.length; i++) {
|
||||||
stride[i] = (Long)strideList.get(i);
|
stride[i] = (Long) strideList.get(i);
|
||||||
}
|
}
|
||||||
long address = (Long)map.get("address");
|
long address = (Long) map.get("address");
|
||||||
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
|
NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
|
||||||
return numpyArray;
|
return numpyArray;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){
|
public static PythonVariables expandInnerDict(PythonVariables pyvars, String key) {
|
||||||
Map dict = pyvars.getDictValue(key);
|
Map dict = pyvars.getDictValue(key);
|
||||||
String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]);
|
String[] keys = (String[]) dict.keySet().toArray(new String[dict.keySet().size()]);
|
||||||
PythonVariables pyvars2 = new PythonVariables();
|
PythonVariables pyvars2 = new PythonVariables();
|
||||||
for (String subkey: keys){
|
for (String subkey : keys) {
|
||||||
Object value = dict.get(subkey);
|
Object value = dict.get(subkey);
|
||||||
if (value instanceof Map){
|
if (value instanceof Map) {
|
||||||
Map map = (Map)value;
|
Map map = (Map) value;
|
||||||
if (map.containsKey("_is_numpy_array")){
|
if (map.containsKey("_is_numpy_array")) {
|
||||||
pyvars2.addNDArray(subkey, mapToNumpyArray(map));
|
pyvars2.addNDArray(subkey, mapToNumpyArray(map));
|
||||||
|
|
||||||
}
|
} else {
|
||||||
else{
|
pyvars2.addDict(subkey, (Map) value);
|
||||||
pyvars2.addDict(subkey, (Map)value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
} else if (value instanceof List) {
|
||||||
else if (value instanceof List){
|
|
||||||
pyvars2.addList(subkey, ((List) value).toArray());
|
pyvars2.addList(subkey, ((List) value).toArray());
|
||||||
}
|
} else if (value instanceof String) {
|
||||||
else if (value instanceof String){
|
System.out.println((String) value);
|
||||||
System.out.println((String)value);
|
|
||||||
pyvars2.addStr(subkey, (String) value);
|
pyvars2.addStr(subkey, (String) value);
|
||||||
}
|
} else if (value instanceof Integer || value instanceof Long) {
|
||||||
else if (value instanceof Integer || value instanceof Long) {
|
|
||||||
Number number = (Number) value;
|
Number number = (Number) value;
|
||||||
pyvars2.addInt(subkey, number.intValue());
|
pyvars2.addInt(subkey, number.intValue());
|
||||||
}
|
} else if (value instanceof Float || value instanceof Double) {
|
||||||
else if (value instanceof Float || value instanceof Double) {
|
|
||||||
Number number = (Number) value;
|
Number number = (Number) value;
|
||||||
pyvars2.addFloat(subkey, number.doubleValue());
|
pyvars2.addFloat(subkey, number.doubleValue());
|
||||||
}
|
} else if (value instanceof NumpyArray) {
|
||||||
else if (value instanceof NumpyArray){
|
pyvars2.addNDArray(subkey, (NumpyArray) value);
|
||||||
pyvars2.addNDArray(subkey, (NumpyArray)value);
|
} else if (value == null) {
|
||||||
}
|
|
||||||
else if (value == null){
|
|
||||||
pyvars2.addStr(subkey, "None"); // FixMe
|
pyvars2.addStr(subkey, "None"); // FixMe
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
throw new RuntimeException("Unsupported type!" + value);
|
throw new RuntimeException("Unsupported type!" + value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return pyvars2;
|
return pyvars2;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long[] jsonArrayToLongArray(JSONArray jsonArray){
|
public static long[] jsonArrayToLongArray(JSONArray jsonArray) {
|
||||||
long[] longs = new long[jsonArray.length()];
|
long[] longs = new long[jsonArray.length()];
|
||||||
for (int i=0; i<longs.length; i++){
|
for (int i = 0; i < longs.length; i++) {
|
||||||
|
|
||||||
longs[i] = jsonArray.getLong(i);
|
longs[i] = jsonArray.getLong(i);
|
||||||
}
|
}
|
||||||
return longs;
|
return longs;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Map<String, Object> toMap(JSONObject jsonobj) {
|
public static Map<String, Object> toMap(JSONObject jsonobj) {
|
||||||
Map<String, Object> map = new HashMap<>();
|
Map<String, Object> map = new HashMap<>();
|
||||||
String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
|
String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]);
|
||||||
for (String key: keys){
|
for (String key : keys) {
|
||||||
Object value = jsonobj.get(key);
|
Object value = jsonobj.get(key);
|
||||||
if (value instanceof JSONArray) {
|
if (value instanceof JSONArray) {
|
||||||
value = toList((JSONArray) value);
|
value = toList((JSONArray) value);
|
||||||
} else if (value instanceof JSONObject) {
|
} else if (value instanceof JSONObject) {
|
||||||
JSONObject jsonobj2 = (JSONObject)value;
|
JSONObject jsonobj2 = (JSONObject) value;
|
||||||
if (jsonobj2.has("_is_numpy_array")){
|
if (jsonobj2.has("_is_numpy_array")) {
|
||||||
value = jsonToNumpyArray(jsonobj2);
|
value = jsonToNumpyArray(jsonobj2);
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
value = toMap(jsonobj2);
|
value = toMap(jsonobj2);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
map.put(key, value);
|
map.put(key, value);
|
||||||
} return map;
|
}
|
||||||
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -265,40 +261,35 @@ public class PythonUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static NumpyArray jsonToNumpyArray(JSONObject map){
|
private static NumpyArray jsonToNumpyArray(JSONObject map) {
|
||||||
String dtypeName = (String)map.get("dtype");
|
String dtypeName = (String) map.get("dtype");
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
if (dtypeName.equals("float64")){
|
if (dtypeName.equals("float64")) {
|
||||||
dtype = DataType.DOUBLE;
|
dtype = DataType.DOUBLE;
|
||||||
}
|
} else if (dtypeName.equals("float32")) {
|
||||||
else if (dtypeName.equals("float32")){
|
|
||||||
dtype = DataType.FLOAT;
|
dtype = DataType.FLOAT;
|
||||||
}
|
} else if (dtypeName.equals("int16")) {
|
||||||
else if (dtypeName.equals("int16")){
|
|
||||||
dtype = DataType.SHORT;
|
dtype = DataType.SHORT;
|
||||||
}
|
} else if (dtypeName.equals("int32")) {
|
||||||
else if (dtypeName.equals("int32")){
|
|
||||||
dtype = DataType.INT;
|
dtype = DataType.INT;
|
||||||
}
|
} else if (dtypeName.equals("int64")) {
|
||||||
else if (dtypeName.equals("int64")){
|
|
||||||
dtype = DataType.LONG;
|
dtype = DataType.LONG;
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||||
}
|
}
|
||||||
List shapeList = (List)map.get("shape");
|
List shapeList = map.getJSONArray("shape").toList();
|
||||||
long[] shape = new long[shapeList.size()];
|
long[] shape = new long[shapeList.size()];
|
||||||
for (int i = 0; i < shape.length; i++) {
|
for (int i = 0; i < shape.length; i++) {
|
||||||
shape[i] = (Long)shapeList.get(i);
|
shape[i] = ((Number) shapeList.get(i)).longValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
List strideList = (List)map.get("shape");
|
List strideList = map.getJSONArray("shape").toList();
|
||||||
long[] stride = new long[strideList.size()];
|
long[] stride = new long[strideList.size()];
|
||||||
for (int i = 0; i < stride.length; i++) {
|
for (int i = 0; i < stride.length; i++) {
|
||||||
stride[i] = (Long)strideList.get(i);
|
stride[i] = ((Number) strideList.get(i)).longValue();
|
||||||
}
|
}
|
||||||
long address = (Long)map.get("address");
|
long address = ((Number) map.get("address")).longValue();
|
||||||
NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
|
NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
|
||||||
return numpyArray;
|
return numpyArray;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,13 +17,19 @@
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
import org.json.JSONArray;
|
import org.json.JSONArray;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Holds python variable names, types and values.
|
* Holds python variable names, types and values.
|
||||||
* Also handles mapping from java types to python types.
|
* Also handles mapping from java types to python types.
|
||||||
|
@ -33,41 +39,31 @@ import java.util.*;
|
||||||
|
|
||||||
@lombok.Data
|
@lombok.Data
|
||||||
public class PythonVariables implements java.io.Serializable {
|
public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
public enum Type{
|
|
||||||
BOOL,
|
|
||||||
STR,
|
|
||||||
INT,
|
|
||||||
FLOAT,
|
|
||||||
NDARRAY,
|
|
||||||
LIST,
|
|
||||||
FILE,
|
|
||||||
DICT
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, INDArray> ndVars = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, List> listVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, BytePointer> bytesVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, java.util.Map<?,?>> dictVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, java.util.Map<?, ?>> dictVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, PythonType.TypeName> vars = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();
|
private java.util.Map<PythonType.TypeName, java.util.Map> maps = new java.util.LinkedHashMap<>();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a copy of the variable
|
* Returns a copy of the variable
|
||||||
* schema in this array without the values
|
* schema in this array without the values
|
||||||
|
*
|
||||||
* @return an empty variables clone
|
* @return an empty variables clone
|
||||||
* with no values
|
* with no values
|
||||||
*/
|
*/
|
||||||
public PythonVariables copySchema(){
|
public PythonVariables copySchema() {
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
for (String varName: getVariables()){
|
for (String varName : getVariables()) {
|
||||||
Type type = getType(varName);
|
PythonType type = getType(varName);
|
||||||
ret.add(varName, type);
|
ret.add(varName, type);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -77,21 +73,19 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public PythonVariables() {
|
public PythonVariables() {
|
||||||
maps.put(PythonVariables.Type.BOOL, boolVariables);
|
maps.put(PythonType.TypeName.BOOL, boolVariables);
|
||||||
maps.put(PythonVariables.Type.STR, strVariables);
|
maps.put(PythonType.TypeName.STR, strVariables);
|
||||||
maps.put(PythonVariables.Type.INT, intVariables);
|
maps.put(PythonType.TypeName.INT, intVariables);
|
||||||
maps.put(PythonVariables.Type.FLOAT, floatVariables);
|
maps.put(PythonType.TypeName.FLOAT, floatVariables);
|
||||||
maps.put(PythonVariables.Type.NDARRAY, ndVars);
|
maps.put(PythonType.TypeName.NDARRAY, ndVars);
|
||||||
maps.put(PythonVariables.Type.LIST, listVariables);
|
maps.put(PythonType.TypeName.LIST, listVariables);
|
||||||
maps.put(PythonVariables.Type.FILE, fileVariables);
|
maps.put(PythonType.TypeName.DICT, dictVariables);
|
||||||
maps.put(PythonVariables.Type.DICT, dictVariables);
|
maps.put(PythonType.TypeName.BYTES, bytesVariables);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @return true if there are no variables.
|
* @return true if there are no variables.
|
||||||
*/
|
*/
|
||||||
public boolean isEmpty() {
|
public boolean isEmpty() {
|
||||||
|
@ -100,12 +94,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name Name of the variable
|
* @param name Name of the variable
|
||||||
* @param type Type of the variable
|
* @param type Type of the variable
|
||||||
*/
|
*/
|
||||||
public void add(String name, Type type){
|
public void add(String name, PythonType type) {
|
||||||
switch (type){
|
switch (type.getName()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
addBool(name);
|
addBool(name);
|
||||||
break;
|
break;
|
||||||
|
@ -124,21 +117,21 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
case LIST:
|
case LIST:
|
||||||
addList(name);
|
addList(name);
|
||||||
break;
|
break;
|
||||||
case FILE:
|
|
||||||
addFile(name);
|
|
||||||
break;
|
|
||||||
case DICT:
|
case DICT:
|
||||||
addDict(name);
|
addDict(name);
|
||||||
|
break;
|
||||||
|
case BYTES:
|
||||||
|
addBytes(name);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
* @param name name of the variable
|
||||||
* @param name name of the variable
|
* @param type type of the variable
|
||||||
* @param type type of the variable
|
|
||||||
* @param value value of the variable (must be instance of expected type)
|
* @param value value of the variable (must be instance of expected type)
|
||||||
*/
|
*/
|
||||||
public void add(String name, Type type, Object value) {
|
public void add(String name, PythonType type, Object value) throws PythonException {
|
||||||
add(name, type);
|
add(name, type);
|
||||||
setValue(name, value);
|
setValue(name, value);
|
||||||
}
|
}
|
||||||
|
@ -148,21 +141,23 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addDict(String name) {
|
public void addDict(String name) {
|
||||||
vars.put(name, PythonVariables.Type.DICT);
|
vars.put(name, PythonType.TypeName.DICT);
|
||||||
dictVariables.put(name,null);
|
dictVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addBool(String name){
|
public void addBool(String name) {
|
||||||
vars.put(name, PythonVariables.Type.BOOL);
|
vars.put(name, PythonType.TypeName.BOOL);
|
||||||
boolVariables.put(name, null);
|
boolVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,10 +165,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addStr(String name){
|
public void addStr(String name) {
|
||||||
vars.put(name, PythonVariables.Type.STR);
|
vars.put(name, PythonType.TypeName.STR);
|
||||||
strVariables.put(name, null);
|
strVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,10 +177,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name){
|
public void addInt(String name) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, null);
|
intVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,10 +189,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name){
|
public void addFloat(String name) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, null);
|
floatVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,10 +201,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addNDArray(String name){
|
public void addNDArray(String name) {
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
ndVars.put(name, null);
|
ndVars.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,99 +213,109 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addList(String name){
|
public void addList(String name) {
|
||||||
vars.put(name, PythonVariables.Type.LIST);
|
vars.put(name, PythonType.TypeName.LIST);
|
||||||
listVariables.put(name, null);
|
listVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
|
||||||
*/
|
|
||||||
public void addFile(String name){
|
|
||||||
vars.put(name, PythonVariables.Type.FILE);
|
|
||||||
fileVariables.put(name, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a boolean variable to
|
* Add a boolean variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addBool(String name, boolean value) {
|
public void addBool(String name, boolean value) {
|
||||||
vars.put(name, PythonVariables.Type.BOOL);
|
vars.put(name, PythonType.TypeName.BOOL);
|
||||||
boolVariables.put(name, value);
|
boolVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a string variable to
|
* Add a string variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addStr(String name, String value) {
|
public void addStr(String name, String value) {
|
||||||
vars.put(name, PythonVariables.Type.STR);
|
vars.put(name, PythonType.TypeName.STR);
|
||||||
strVariables.put(name, value);
|
strVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add an int variable to
|
* Add an int variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name, int value) {
|
public void addInt(String name, int value) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, (long)value);
|
intVariables.put(name, (long) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a long variable to
|
* Add a long variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name, long value) {
|
public void addInt(String name, long value) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, value);
|
intVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a double variable to
|
* Add a double variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name, double value) {
|
public void addFloat(String name, double value) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, value);
|
floatVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a float variable to
|
* Add a float variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name, float value) {
|
public void addFloat(String name, float value) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, (double)value);
|
floatVariables.put(name, (double) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
* @param name the field to add
|
*
|
||||||
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addNDArray(String name, NumpyArray value) {
|
public void addNDArray(String name, NumpyArray value) {
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
|
ndVars.put(name, value.getNd4jArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a null variable to
|
||||||
|
* the set of variables
|
||||||
|
* to describe the type but no value
|
||||||
|
*
|
||||||
|
* @param name the field to add
|
||||||
|
* @param value the value to add
|
||||||
|
*/
|
||||||
|
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
||||||
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
ndVars.put(name, value);
|
ndVars.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -314,117 +323,63 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
* @param name the field to add
|
*
|
||||||
* @param value the value to add
|
* @param name the field to add
|
||||||
*/
|
|
||||||
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
|
||||||
ndVars.put(name, new NumpyArray(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addList(String name, Object[] value) {
|
public void addList(String name, Object[] value) {
|
||||||
vars.put(name, PythonVariables.Type.LIST);
|
vars.put(name, PythonType.TypeName.LIST);
|
||||||
listVariables.put(name, value);
|
listVariables.put(name, Arrays.asList(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
* @param name the field to add
|
*
|
||||||
* @param value the value to add
|
* @param name the field to add
|
||||||
*/
|
|
||||||
public void addFile(String name, String value) {
|
|
||||||
vars.put(name, PythonVariables.Type.FILE);
|
|
||||||
fileVariables.put(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addDict(String name, java.util.Map value) {
|
public void addDict(String name, java.util.Map value) {
|
||||||
vars.put(name, PythonVariables.Type.DICT);
|
vars.put(name, PythonType.TypeName.DICT);
|
||||||
dictVariables.put(name, value);
|
dictVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void addBytes(String name){
|
||||||
|
vars.put(name, PythonType.TypeName.BYTES);
|
||||||
|
bytesVariables.put(name, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addBytes(String name, BytePointer value){
|
||||||
|
vars.put(name, PythonType.TypeName.BYTES);
|
||||||
|
bytesVariables.put(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// public void addBytes(String name, ByteBuffer value){
|
||||||
|
// Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress((value.address());
|
||||||
|
// BytePointer bp = new BytePointer(ptr);
|
||||||
|
// addBytes(name, bp);
|
||||||
|
// }
|
||||||
/**
|
/**
|
||||||
*
|
* @param name name of the variable
|
||||||
* @param name name of the variable
|
|
||||||
* @param value new value for the variable
|
* @param value new value for the variable
|
||||||
*/
|
*/
|
||||||
public void setValue(String name, Object value) {
|
public void setValue(String name, Object value) throws PythonException {
|
||||||
Type type = vars.get(name);
|
PythonType.TypeName type = vars.get(name);
|
||||||
if (type == PythonVariables.Type.BOOL){
|
maps.get(type).put(name, PythonType.valueOf(type).convert(value));
|
||||||
boolVariables.put(name, (Boolean)value);
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.INT){
|
|
||||||
Number number = (Number) value;
|
|
||||||
intVariables.put(name, number.longValue());
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.FLOAT){
|
|
||||||
Number number = (Number) value;
|
|
||||||
floatVariables.put(name, number.doubleValue());
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.NDARRAY){
|
|
||||||
if (value instanceof NumpyArray){
|
|
||||||
ndVars.put(name, (NumpyArray)value);
|
|
||||||
}
|
|
||||||
else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
|
|
||||||
ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
throw new RuntimeException("Unsupported type: " + value.getClass().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.LIST) {
|
|
||||||
if (value instanceof java.util.List) {
|
|
||||||
value = ((java.util.List) value).toArray();
|
|
||||||
listVariables.put(name, (Object[]) value);
|
|
||||||
}
|
|
||||||
else if(value instanceof org.json.JSONArray) {
|
|
||||||
org.json.JSONArray jsonArray = (org.json.JSONArray) value;
|
|
||||||
Object[] copyArr = new Object[jsonArray.length()];
|
|
||||||
for(int i = 0; i < copyArr.length; i++) {
|
|
||||||
copyArr[i] = jsonArray.get(i);
|
|
||||||
}
|
|
||||||
listVariables.put(name, copyArr);
|
|
||||||
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
listVariables.put(name, (Object[]) value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if(type == PythonVariables.Type.DICT) {
|
|
||||||
dictVariables.put(name,(java.util.Map<?,?>) value);
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.FILE){
|
|
||||||
fileVariables.put(name, (String)value);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
strVariables.put(name, (String)value);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Do a general object lookup.
|
* Do a general object lookup.
|
||||||
* The look up will happen relative to the {@link Type}
|
* The look up will happen relative to the {@link PythonType}
|
||||||
* of variable is described in the
|
* of variable is described in the
|
||||||
|
*
|
||||||
* @param name the name of the variable to get
|
* @param name the name of the variable to get
|
||||||
* @return teh value for the variable with the given name
|
* @return teh value for the variable with the given name
|
||||||
*/
|
*/
|
||||||
public Object getValue(String name) {
|
public Object getValue(String name) {
|
||||||
Type type = vars.get(name);
|
PythonType.TypeName type = vars.get(name);
|
||||||
java.util.Map map = maps.get(type);
|
java.util.Map map = maps.get(type);
|
||||||
return map.get(name);
|
return map.get(name);
|
||||||
}
|
}
|
||||||
|
@ -432,6 +387,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a boolean variable with the given name.
|
* Returns a boolean variable with the given name.
|
||||||
|
*
|
||||||
* @param name the variable name to get the value for
|
* @param name the variable name to get the value for
|
||||||
* @return the retrieved boolean value
|
* @return the retrieved boolean value
|
||||||
*/
|
*/
|
||||||
|
@ -440,80 +396,78 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the dictionary value
|
* @return the dictionary value
|
||||||
*/
|
*/
|
||||||
public java.util.Map<?,?> getDictValue(String name) {
|
public java.util.Map<?, ?> getDictValue(String name) {
|
||||||
return dictVariables.get(name);
|
return dictVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
/**
|
* /**
|
||||||
*
|
*
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the string value
|
* @return the string value
|
||||||
*/
|
*/
|
||||||
public String getStrValue(String name){
|
public String getStrValue(String name) {
|
||||||
return strVariables.get(name);
|
return strVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the long value
|
* @return the long value
|
||||||
*/
|
*/
|
||||||
public Long getIntValue(String name){
|
public Long getIntValue(String name) {
|
||||||
return intVariables.get(name);
|
return intVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the float value
|
* @return the float value
|
||||||
*/
|
*/
|
||||||
public Double getFloatValue(String name){
|
public Double getFloatValue(String name) {
|
||||||
return floatVariables.get(name);
|
return floatVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the numpy array value
|
* @return the numpy array value
|
||||||
*/
|
*/
|
||||||
public NumpyArray getNDArrayValue(String name){
|
public INDArray getNDArrayValue(String name) {
|
||||||
return ndVars.get(name);
|
return ndVars.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the list value as an object array
|
* @return the list value as an object array
|
||||||
*/
|
*/
|
||||||
public Object[] getListValue(String name){
|
public List getListValue(String name) {
|
||||||
return listVariables.get(name);
|
return listVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the value of the given file name
|
* @return the bytes value as a BytePointer
|
||||||
*/
|
*/
|
||||||
public String getFileValue(String name){
|
public BytePointer getBytesValue(String name){return bytesVariables.get(name);}
|
||||||
return fileVariables.get(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the type for the given variable name
|
* Returns the type for the given variable name
|
||||||
|
*
|
||||||
* @param name the name of the variable to get the type for
|
* @param name the name of the variable to get the type for
|
||||||
* @return the type for the given variable
|
* @return the type for the given variable
|
||||||
*/
|
*/
|
||||||
public Type getType(String name){
|
public PythonType getType(String name){
|
||||||
return vars.get(name);
|
try{
|
||||||
|
return PythonType.valueOf(vars.get(name)); // will never fail
|
||||||
|
}catch (Exception e)
|
||||||
|
{
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all the variables present as a string array
|
* Get all the variables present as a string array
|
||||||
|
*
|
||||||
* @return the variable names for this variable sset
|
* @return the variable names for this variable sset
|
||||||
*/
|
*/
|
||||||
public String[] getVariables() {
|
public String[] getVariables() {
|
||||||
|
@ -524,11 +478,12 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This variables set as its json representation (an array of json objects)
|
* This variables set as its json representation (an array of json objects)
|
||||||
|
*
|
||||||
* @return the json array output
|
* @return the json array output
|
||||||
*/
|
*/
|
||||||
public org.json.JSONArray toJSON(){
|
public org.json.JSONArray toJSON() {
|
||||||
org.json.JSONArray arr = new org.json.JSONArray();
|
org.json.JSONArray arr = new org.json.JSONArray();
|
||||||
for (String varName: getVariables()){
|
for (String varName : getVariables()) {
|
||||||
org.json.JSONObject var = new org.json.JSONObject();
|
org.json.JSONObject var = new org.json.JSONObject();
|
||||||
var.put("name", varName);
|
var.put("name", varName);
|
||||||
String varType = getType(varName).toString();
|
String varType = getType(varName).toString();
|
||||||
|
@ -542,13 +497,14 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Create a schema from a map.
|
* Create a schema from a map.
|
||||||
* This is an empty PythonVariables
|
* This is an empty PythonVariables
|
||||||
* that just contains names and types with no values
|
* that just contains names and types with no values
|
||||||
|
*
|
||||||
* @param inputTypes the input types to convert
|
* @param inputTypes the input types to convert
|
||||||
* @return the schema from the given map
|
* @return the schema from the given map
|
||||||
*/
|
*/
|
||||||
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
|
public static PythonVariables schemaFromMap(java.util.Map<String, String> inputTypes) throws Exception{
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
for(java.util.Map.Entry<String,String> entry : inputTypes.entrySet()) {
|
for (java.util.Map.Entry<String, String> entry : inputTypes.entrySet()) {
|
||||||
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
|
ret.add(entry.getKey(), PythonType.valueOf(entry.getValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -557,39 +513,17 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
/**
|
/**
|
||||||
* Get the python variable state relative to the
|
* Get the python variable state relative to the
|
||||||
* input json array
|
* input json array
|
||||||
|
*
|
||||||
* @param jsonArray the input json array
|
* @param jsonArray the input json array
|
||||||
* @return the python variables based on the input json array
|
* @return the python variables based on the input json array
|
||||||
*/
|
*/
|
||||||
public static PythonVariables fromJSON(org.json.JSONArray jsonArray){
|
public static PythonVariables fromJSON(org.json.JSONArray jsonArray) {
|
||||||
PythonVariables pyvars = new PythonVariables();
|
PythonVariables pyvars = new PythonVariables();
|
||||||
for (int i = 0; i < jsonArray.length(); i++) {
|
for (int i = 0; i < jsonArray.length(); i++) {
|
||||||
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
|
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
|
||||||
String varName = (String)input.get("name");
|
String varName = (String) input.get("name");
|
||||||
String varType = (String)input.get("type");
|
String varType = (String) input.get("type");
|
||||||
if (varType.equals("BOOL")) {
|
pyvars.maps.get(PythonType.TypeName.valueOf(varType)).put(varName, null);
|
||||||
pyvars.addBool(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("INT")) {
|
|
||||||
pyvars.addInt(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("FlOAT")){
|
|
||||||
pyvars.addFloat(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("STR")) {
|
|
||||||
pyvars.addStr(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("LIST")) {
|
|
||||||
pyvars.addList(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("FILE")){
|
|
||||||
pyvars.addFile(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("NDARRAY")) {
|
|
||||||
pyvars.addNDArray(varName);
|
|
||||||
}
|
|
||||||
else if(varType.equals("DICT")) {
|
|
||||||
pyvars.addDict(varName);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return pyvars;
|
return pyvars;
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
|
|
||||||
import sys
|
|
||||||
this = sys.modules[__name__]
|
|
||||||
for n in dir():
|
|
||||||
if n[0]!='_': delattr(this, n)
|
|
|
@ -1 +0,0 @@
|
||||||
loc = {}
|
|
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
def __is_numpy_array(x):
|
|
||||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
|
||||||
|
|
||||||
def maybe_serialize_ndarray_metadata(x):
|
|
||||||
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_ndarray_metadata(x):
|
|
||||||
return {"address": x.__array_interface__['data'][0],
|
|
||||||
"shape": x.shape,
|
|
||||||
"strides": x.strides,
|
|
||||||
"dtype": str(x.dtype),
|
|
||||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def is_json_ready(key, value):
|
|
||||||
return key is not 'f2' and not inspect.ismodule(value) \
|
|
||||||
and not hasattr(value, '__call__')
|
|
||||||
|
|
|
@ -1,202 +0,0 @@
|
||||||
#patch
|
|
||||||
|
|
||||||
"""Implementation of __array_function__ overrides from NEP-18."""
|
|
||||||
import collections
|
|
||||||
import functools
|
|
||||||
import os
|
|
||||||
|
|
||||||
from numpy.core._multiarray_umath import (
|
|
||||||
add_docstring, implement_array_function, _get_implementing_args)
|
|
||||||
from numpy.compat._inspect import getargspec
|
|
||||||
|
|
||||||
|
|
||||||
ENABLE_ARRAY_FUNCTION = bool(
|
|
||||||
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
|
|
||||||
|
|
||||||
|
|
||||||
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
|
|
||||||
|
|
||||||
|
|
||||||
_add_docstring = add_docstring
|
|
||||||
|
|
||||||
|
|
||||||
def add_docstring(*args):
|
|
||||||
try:
|
|
||||||
_add_docstring(*args)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
add_docstring(
|
|
||||||
implement_array_function,
|
|
||||||
"""
|
|
||||||
Implement a function with checks for __array_function__ overrides.
|
|
||||||
|
|
||||||
All arguments are required, and can only be passed by position.
|
|
||||||
|
|
||||||
Arguments
|
|
||||||
---------
|
|
||||||
implementation : function
|
|
||||||
Function that implements the operation on NumPy array without
|
|
||||||
overrides when called like ``implementation(*args, **kwargs)``.
|
|
||||||
public_api : function
|
|
||||||
Function exposed by NumPy's public API originally called like
|
|
||||||
``public_api(*args, **kwargs)`` on which arguments are now being
|
|
||||||
checked.
|
|
||||||
relevant_args : iterable
|
|
||||||
Iterable of arguments to check for __array_function__ methods.
|
|
||||||
args : tuple
|
|
||||||
Arbitrary positional arguments originally passed into ``public_api``.
|
|
||||||
kwargs : dict
|
|
||||||
Arbitrary keyword arguments originally passed into ``public_api``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Result from calling ``implementation()`` or an ``__array_function__``
|
|
||||||
method, as appropriate.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
TypeError : if no implementation is found.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
# exposed for testing purposes; used internally by implement_array_function
|
|
||||||
add_docstring(
|
|
||||||
_get_implementing_args,
|
|
||||||
"""
|
|
||||||
Collect arguments on which to call __array_function__.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
relevant_args : iterable of array-like
|
|
||||||
Iterable of possibly array-like arguments to check for
|
|
||||||
__array_function__ methods.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Sequence of arguments with __array_function__ methods, in the order in
|
|
||||||
which they should be called.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
|
|
||||||
|
|
||||||
|
|
||||||
def verify_matching_signatures(implementation, dispatcher):
|
|
||||||
"""Verify that a dispatcher function has the right signature."""
|
|
||||||
implementation_spec = ArgSpec(*getargspec(implementation))
|
|
||||||
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
|
|
||||||
|
|
||||||
if (implementation_spec.args != dispatcher_spec.args or
|
|
||||||
implementation_spec.varargs != dispatcher_spec.varargs or
|
|
||||||
implementation_spec.keywords != dispatcher_spec.keywords or
|
|
||||||
(bool(implementation_spec.defaults) !=
|
|
||||||
bool(dispatcher_spec.defaults)) or
|
|
||||||
(implementation_spec.defaults is not None and
|
|
||||||
len(implementation_spec.defaults) !=
|
|
||||||
len(dispatcher_spec.defaults))):
|
|
||||||
raise RuntimeError('implementation and dispatcher for %s have '
|
|
||||||
'different function signatures' % implementation)
|
|
||||||
|
|
||||||
if implementation_spec.defaults is not None:
|
|
||||||
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
|
|
||||||
raise RuntimeError('dispatcher functions can only use None for '
|
|
||||||
'default argument values')
|
|
||||||
|
|
||||||
|
|
||||||
def set_module(module):
|
|
||||||
"""Decorator for overriding __module__ on a function or class.
|
|
||||||
|
|
||||||
Example usage::
|
|
||||||
|
|
||||||
@set_module('numpy')
|
|
||||||
def example():
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert example.__module__ == 'numpy'
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
if module is not None:
|
|
||||||
func.__module__ = module
|
|
||||||
return func
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def array_function_dispatch(dispatcher, module=None, verify=True,
|
|
||||||
docs_from_dispatcher=False):
|
|
||||||
"""Decorator for adding dispatch with the __array_function__ protocol.
|
|
||||||
|
|
||||||
See NEP-18 for example usage.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
dispatcher : callable
|
|
||||||
Function that when called like ``dispatcher(*args, **kwargs)`` with
|
|
||||||
arguments from the NumPy function call returns an iterable of
|
|
||||||
array-like arguments to check for ``__array_function__``.
|
|
||||||
module : str, optional
|
|
||||||
__module__ attribute to set on new function, e.g., ``module='numpy'``.
|
|
||||||
By default, module is copied from the decorated function.
|
|
||||||
verify : bool, optional
|
|
||||||
If True, verify the that the signature of the dispatcher and decorated
|
|
||||||
function signatures match exactly: all required and optional arguments
|
|
||||||
should appear in order with the same names, but the default values for
|
|
||||||
all optional arguments should be ``None``. Only disable verification
|
|
||||||
if the dispatcher's signature needs to deviate for some particular
|
|
||||||
reason, e.g., because the function has a signature like
|
|
||||||
``func(*args, **kwargs)``.
|
|
||||||
docs_from_dispatcher : bool, optional
|
|
||||||
If True, copy docs from the dispatcher function onto the dispatched
|
|
||||||
function, rather than from the implementation. This is useful for
|
|
||||||
functions defined in C, which otherwise don't have docstrings.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Function suitable for decorating the implementation of a NumPy function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not ENABLE_ARRAY_FUNCTION:
|
|
||||||
# __array_function__ requires an explicit opt-in for now
|
|
||||||
def decorator(implementation):
|
|
||||||
if module is not None:
|
|
||||||
implementation.__module__ = module
|
|
||||||
if docs_from_dispatcher:
|
|
||||||
add_docstring(implementation, dispatcher.__doc__)
|
|
||||||
return implementation
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def decorator(implementation):
|
|
||||||
if verify:
|
|
||||||
verify_matching_signatures(implementation, dispatcher)
|
|
||||||
|
|
||||||
if docs_from_dispatcher:
|
|
||||||
add_docstring(implementation, dispatcher.__doc__)
|
|
||||||
|
|
||||||
@functools.wraps(implementation)
|
|
||||||
def public_api(*args, **kwargs):
|
|
||||||
relevant_args = dispatcher(*args, **kwargs)
|
|
||||||
return implement_array_function(
|
|
||||||
implementation, public_api, relevant_args, args, kwargs)
|
|
||||||
|
|
||||||
if module is not None:
|
|
||||||
public_api.__module__ = module
|
|
||||||
|
|
||||||
# TODO: remove this when we drop Python 2 support (functools.wraps)
|
|
||||||
# adds __wrapped__ automatically in later versions)
|
|
||||||
public_api.__wrapped__ = implementation
|
|
||||||
|
|
||||||
return public_api
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def array_function_from_dispatcher(
|
|
||||||
implementation, module=None, verify=True, docs_from_dispatcher=True):
|
|
||||||
"""Like array_function_dispatcher, but with function arguments flipped."""
|
|
||||||
|
|
||||||
def decorator(dispatcher):
|
|
||||||
return array_function_dispatch(
|
|
||||||
dispatcher, module, verify=verify,
|
|
||||||
docs_from_dispatcher=docs_from_dispatcher)(implementation)
|
|
||||||
return decorator
|
|
|
@ -1,172 +0,0 @@
|
||||||
#patch 1
|
|
||||||
|
|
||||||
"""
|
|
||||||
========================
|
|
||||||
Random Number Generation
|
|
||||||
========================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Utility functions
|
|
||||||
==============================================================================
|
|
||||||
random_sample Uniformly distributed floats over ``[0, 1)``.
|
|
||||||
random Alias for `random_sample`.
|
|
||||||
bytes Uniformly distributed random bytes.
|
|
||||||
random_integers Uniformly distributed integers in a given range.
|
|
||||||
permutation Randomly permute a sequence / generate a random sequence.
|
|
||||||
shuffle Randomly permute a sequence in place.
|
|
||||||
seed Seed the random number generator.
|
|
||||||
choice Random sample from 1-D array.
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Compatibility functions
|
|
||||||
==============================================================================
|
|
||||||
rand Uniformly distributed values.
|
|
||||||
randn Normally distributed values.
|
|
||||||
ranf Uniformly distributed floating point numbers.
|
|
||||||
randint Uniformly distributed integers in a given range.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Univariate distributions
|
|
||||||
==============================================================================
|
|
||||||
beta Beta distribution over ``[0, 1]``.
|
|
||||||
binomial Binomial distribution.
|
|
||||||
chisquare :math:`\\chi^2` distribution.
|
|
||||||
exponential Exponential distribution.
|
|
||||||
f F (Fisher-Snedecor) distribution.
|
|
||||||
gamma Gamma distribution.
|
|
||||||
geometric Geometric distribution.
|
|
||||||
gumbel Gumbel distribution.
|
|
||||||
hypergeometric Hypergeometric distribution.
|
|
||||||
laplace Laplace distribution.
|
|
||||||
logistic Logistic distribution.
|
|
||||||
lognormal Log-normal distribution.
|
|
||||||
logseries Logarithmic series distribution.
|
|
||||||
negative_binomial Negative binomial distribution.
|
|
||||||
noncentral_chisquare Non-central chi-square distribution.
|
|
||||||
noncentral_f Non-central F distribution.
|
|
||||||
normal Normal / Gaussian distribution.
|
|
||||||
pareto Pareto distribution.
|
|
||||||
poisson Poisson distribution.
|
|
||||||
power Power distribution.
|
|
||||||
rayleigh Rayleigh distribution.
|
|
||||||
triangular Triangular distribution.
|
|
||||||
uniform Uniform distribution.
|
|
||||||
vonmises Von Mises circular distribution.
|
|
||||||
wald Wald (inverse Gaussian) distribution.
|
|
||||||
weibull Weibull distribution.
|
|
||||||
zipf Zipf's distribution over ranked data.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Multivariate distributions
|
|
||||||
==============================================================================
|
|
||||||
dirichlet Multivariate generalization of Beta distribution.
|
|
||||||
multinomial Multivariate generalization of the binomial distribution.
|
|
||||||
multivariate_normal Multivariate generalization of the normal distribution.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Standard distributions
|
|
||||||
==============================================================================
|
|
||||||
standard_cauchy Standard Cauchy-Lorentz distribution.
|
|
||||||
standard_exponential Standard exponential distribution.
|
|
||||||
standard_gamma Standard Gamma distribution.
|
|
||||||
standard_normal Standard normal distribution.
|
|
||||||
standard_t Standard Student's t-distribution.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Internal functions
|
|
||||||
==============================================================================
|
|
||||||
get_state Get tuple representing internal state of generator.
|
|
||||||
set_state Set state of generator.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import division, absolute_import, print_function
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'beta',
|
|
||||||
'binomial',
|
|
||||||
'bytes',
|
|
||||||
'chisquare',
|
|
||||||
'choice',
|
|
||||||
'dirichlet',
|
|
||||||
'exponential',
|
|
||||||
'f',
|
|
||||||
'gamma',
|
|
||||||
'geometric',
|
|
||||||
'get_state',
|
|
||||||
'gumbel',
|
|
||||||
'hypergeometric',
|
|
||||||
'laplace',
|
|
||||||
'logistic',
|
|
||||||
'lognormal',
|
|
||||||
'logseries',
|
|
||||||
'multinomial',
|
|
||||||
'multivariate_normal',
|
|
||||||
'negative_binomial',
|
|
||||||
'noncentral_chisquare',
|
|
||||||
'noncentral_f',
|
|
||||||
'normal',
|
|
||||||
'pareto',
|
|
||||||
'permutation',
|
|
||||||
'poisson',
|
|
||||||
'power',
|
|
||||||
'rand',
|
|
||||||
'randint',
|
|
||||||
'randn',
|
|
||||||
'random_integers',
|
|
||||||
'random_sample',
|
|
||||||
'rayleigh',
|
|
||||||
'seed',
|
|
||||||
'set_state',
|
|
||||||
'shuffle',
|
|
||||||
'standard_cauchy',
|
|
||||||
'standard_exponential',
|
|
||||||
'standard_gamma',
|
|
||||||
'standard_normal',
|
|
||||||
'standard_t',
|
|
||||||
'triangular',
|
|
||||||
'uniform',
|
|
||||||
'vonmises',
|
|
||||||
'wald',
|
|
||||||
'weibull',
|
|
||||||
'zipf'
|
|
||||||
]
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
|
|
||||||
try:
|
|
||||||
from .mtrand import *
|
|
||||||
# Some aliases:
|
|
||||||
ranf = random = sample = random_sample
|
|
||||||
__all__.extend(['ranf', 'random', 'sample'])
|
|
||||||
except:
|
|
||||||
warnings.warn("numpy.random is not available when using multiple interpreters!")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __RandomState_ctor():
|
|
||||||
"""Return a RandomState instance.
|
|
||||||
|
|
||||||
This function exists solely to assist (un)pickling.
|
|
||||||
|
|
||||||
Note that the state of the RandomState returned here is irrelevant, as this function's
|
|
||||||
entire purpose is to return a newly allocated RandomState whose state pickle can set.
|
|
||||||
Consequently the RandomState returned by this function is a freshly allocated copy
|
|
||||||
with a seed=0.
|
|
||||||
|
|
||||||
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
|
|
||||||
|
|
||||||
"""
|
|
||||||
return RandomState(seed=0)
|
|
||||||
|
|
||||||
from numpy._pytesttester import PytestTester
|
|
||||||
test = PytestTester(__name__)
|
|
||||||
del PytestTester
|
|
|
@ -3,13 +3,13 @@ import traceback
|
||||||
import json
|
import json
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
__python_exception__ = ""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
pass
|
pass
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
__python_exception__ = ex
|
||||||
try:
|
try:
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
def __is_numpy_array(x):
|
|
||||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
|
||||||
|
|
||||||
def __maybe_serialize_ndarray_metadata(x):
|
|
||||||
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_ndarray_metadata(x):
|
|
||||||
return {"address": x.__array_interface__['data'][0],
|
|
||||||
"shape": x.shape,
|
|
||||||
"strides": x.strides,
|
|
||||||
"dtype": str(x.dtype),
|
|
||||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_list(x):
|
|
||||||
import json
|
|
||||||
return json.dumps(__recursive_serialize_list(x))
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_dict(x):
|
|
||||||
import json
|
|
||||||
return json.dumps(__recursive_serialize_dict(x))
|
|
||||||
|
|
||||||
def __recursive_serialize_list(x):
|
|
||||||
out = []
|
|
||||||
for i in x:
|
|
||||||
if __is_numpy_array(i):
|
|
||||||
out.append(__serialize_ndarray_metadata(i))
|
|
||||||
elif isinstance(i, (list, tuple)):
|
|
||||||
out.append(__recursive_serialize_list(i))
|
|
||||||
elif isinstance(i, dict):
|
|
||||||
out.append(__recursive_serialize_dict(i))
|
|
||||||
else:
|
|
||||||
out.append(i)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __recursive_serialize_dict(x):
|
|
||||||
out = {}
|
|
||||||
for k in x:
|
|
||||||
v = x[k]
|
|
||||||
if __is_numpy_array(v):
|
|
||||||
out[k] = __serialize_ndarray_metadata(v)
|
|
||||||
elif isinstance(v, (list, tuple)):
|
|
||||||
out[k] = __recursive_serialize_list(v)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
out[k] = __recursive_serialize_dict(v)
|
|
||||||
else:
|
|
||||||
out[k] = v
|
|
||||||
return out
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import javax.annotation.concurrent.NotThreadSafe;
|
||||||
|
|
||||||
|
@NotThreadSafe
|
||||||
|
public class TestPythonContextManager {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInt() throws Exception{
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("a = 1");
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("a = 2");
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("a = 3");
|
||||||
|
|
||||||
|
|
||||||
|
Python.setContext("context1");
|
||||||
|
Assert.assertEquals(1, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Assert.assertEquals(2, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Assert.assertEquals(3, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNDArray() throws Exception{
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 1");
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 2");
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 3");
|
||||||
|
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("a += 1");
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("a += 2");
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("a += 3");
|
||||||
|
|
||||||
|
INDArray arr = Nd4j.create(DataType.DOUBLE, 3, 2);
|
||||||
|
Python.setContext("context1");
|
||||||
|
Assert.assertEquals(arr.add(2), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Assert.assertEquals(arr.add(4), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Assert.assertEquals(arr.add(6), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.var;
|
||||||
|
import org.json.JSONArray;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
public class TestPythonDict {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonDictFromMap() throws Exception{
|
||||||
|
Map<Object, Object> map = new HashMap<>();
|
||||||
|
map.put("a", 1);
|
||||||
|
map.put("b", "a");
|
||||||
|
map.put("1", Arrays.asList(1, 2, 3, "4", Arrays.asList("x", 2.3)));
|
||||||
|
Map<Object, Object> innerMap = new HashMap<>();
|
||||||
|
innerMap.put("k", 32);
|
||||||
|
map.put("inner", innerMap);
|
||||||
|
map.put("ndarray", Nd4j.linspace(1, 4, 4));
|
||||||
|
innerMap.put("ndarray", Nd4j.linspace(5, 8, 4));
|
||||||
|
PythonObject dict = new PythonObject(map);
|
||||||
|
assertEquals(map.size(), Python.len(dict).toInt());
|
||||||
|
assertEquals("{'a': 1, '1': [1, 2, 3, '4', ['" +
|
||||||
|
"x', 2.3]], 'b': 'a', 'inner': {'k': 32," +
|
||||||
|
" 'ndarray': array([5., 6., 7., 8.], dty" +
|
||||||
|
"pe=float32)}, 'ndarray': array([1., 2., " +
|
||||||
|
"3., 4.], dtype=float32)}",
|
||||||
|
dict.toString());
|
||||||
|
Map map2 = dict.toMap();
|
||||||
|
PythonObject dict2 = new PythonObject(map2);
|
||||||
|
assertEquals(dict.toString(), dict2.toString());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,75 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.python;
|
|
||||||
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
@javax.annotation.concurrent.NotThreadSafe
|
|
||||||
public class TestPythonExecutionSandbox {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testInt(){
|
|
||||||
PythonExecutioner.setInterpreter("interp1");
|
|
||||||
PythonExecutioner.exec("a = 1");
|
|
||||||
PythonExecutioner.setInterpreter("interp2");
|
|
||||||
PythonExecutioner.exec("a = 2");
|
|
||||||
PythonExecutioner.setInterpreter("interp3");
|
|
||||||
PythonExecutioner.exec("a = 3");
|
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp1");
|
|
||||||
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp2");
|
|
||||||
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp3");
|
|
||||||
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNDArray(){
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("import numpy as np");
|
|
||||||
PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
//PythonExecutioner.exec("import numpy as np");
|
|
||||||
PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("a += 2");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("a += 3");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
//PythonExecutioner.exec("import numpy as np");
|
|
||||||
// PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNumpyRandom(){
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -15,13 +15,17 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
|
||||||
@javax.annotation.concurrent.NotThreadSafe
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
@ -29,12 +33,12 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
|
|
||||||
@org.junit.Test
|
@org.junit.Test
|
||||||
public void testPythonSysVersion() {
|
public void testPythonSysVersion() throws PythonException {
|
||||||
PythonExecutioner.exec("import sys; print(sys.version)");
|
Python.exec("import sys; print(sys.version)");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStr() throws Exception{
|
public void testStr() throws Exception {
|
||||||
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
@ -46,7 +50,7 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + ' ' + y";
|
String code = "z = x + ' ' + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
String z = pyOutputs.getStrValue("z");
|
String z = pyOutputs.getStrValue("z");
|
||||||
|
|
||||||
|
@ -56,7 +60,7 @@ public class TestPythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testInt()throws Exception{
|
public void testInt() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -68,7 +72,7 @@ public class TestPythonExecutioner {
|
||||||
pyOutputs.addInt("z");
|
pyOutputs.addInt("z");
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
long z = pyOutputs.getIntValue("z");
|
long z = pyOutputs.getIntValue("z");
|
||||||
|
|
||||||
|
@ -77,7 +81,7 @@ public class TestPythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testList() throws Exception{
|
public void testList() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -92,30 +96,28 @@ public class TestPythonExecutioner {
|
||||||
pyOutputs.addList("z");
|
pyOutputs.addList("z");
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
Object[] z = pyOutputs.getListValue("z");
|
Object[] z = pyOutputs.getListValue("z").toArray();
|
||||||
|
|
||||||
Assert.assertEquals(z.length, x.length + y.length);
|
Assert.assertEquals(z.length, x.length + y.length);
|
||||||
|
|
||||||
for (int i = 0; i < x.length; i++) {
|
for (int i = 0; i < x.length; i++) {
|
||||||
if(x[i] instanceof Number) {
|
if (x[i] instanceof Number) {
|
||||||
Number xNum = (Number) x[i];
|
Number xNum = (Number) x[i];
|
||||||
Number zNum = (Number) z[i];
|
Number zNum = (Number) z[i];
|
||||||
Assert.assertEquals(xNum.intValue(), zNum.intValue());
|
Assert.assertEquals(xNum.intValue(), zNum.intValue());
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
Assert.assertEquals(x[i], z[i]);
|
Assert.assertEquals(x[i], z[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < y.length; i++){
|
for (int i = 0; i < y.length; i++) {
|
||||||
if(y[i] instanceof Number) {
|
if (y[i] instanceof Number) {
|
||||||
Number yNum = (Number) y[i];
|
Number yNum = (Number) y[i];
|
||||||
Number zNum = (Number) z[x.length + i];
|
Number zNum = (Number) z[x.length + i];
|
||||||
Assert.assertEquals(yNum.intValue(), zNum.intValue());
|
Assert.assertEquals(yNum.intValue(), zNum.intValue());
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
Assert.assertEquals(y[i], z[x.length + i]);
|
Assert.assertEquals(y[i], z[x.length + i]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -125,7 +127,7 @@ public class TestPythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayFloat()throws Exception{
|
public void testNDArrayFloat() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -135,8 +137,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
|
@ -144,12 +146,13 @@ public class TestPythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTensorflowCustomAnaconda() {
|
@Ignore
|
||||||
PythonExecutioner.exec("import tensorflow as tf");
|
public void testTensorflowCustomAnaconda() throws PythonException {
|
||||||
|
Python.exec("import tensorflow as tf");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayDouble()throws Exception {
|
public void testNDArrayDouble() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -159,14 +162,14 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayShort()throws Exception{
|
public void testNDArrayShort() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -176,15 +179,15 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayInt()throws Exception{
|
public void testNDArrayInt() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -194,15 +197,15 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayLong()throws Exception{
|
public void testNDArrayLong() throws Exception {
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
@ -212,12 +215,91 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferInput() throws Exception{
|
||||||
|
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs= new PythonVariables();
|
||||||
|
pyOutputs.addStr("out");
|
||||||
|
|
||||||
|
String code = "out = buff.decode()";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferOutputNoCopy() throws Exception{
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
pyOutputs.addBytes("buff"); // same name as input, because inplace update
|
||||||
|
|
||||||
|
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferOutputWithCopy() throws Exception{
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
pyOutputs.addBytes("out");
|
||||||
|
|
||||||
|
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97\nout=bytes(buff)";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
|
||||||
|
}
|
||||||
|
@Test
|
||||||
|
public void testBadCode() throws Exception{
|
||||||
|
Python.setContext("badcode");
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
pyInputs.addNDArray("x", Nd4j.zeros(DataType.LONG, 2, 3));
|
||||||
|
pyInputs.addNDArray("y", Nd4j.ones(DataType.LONG, 2, 3));
|
||||||
|
pyOutputs.addNDArray("z");
|
||||||
|
|
||||||
|
String code = "z = x + a";
|
||||||
|
|
||||||
|
try{
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
fail("No exception thrown");
|
||||||
|
} catch (PythonException pe ){
|
||||||
|
Assert.assertEquals("NameError: name 'a' is not defined", pe.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
Python.setMainContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,326 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
public class TestPythonJob {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonJobBasic() throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code = "c = a + b";
|
||||||
|
PythonJob job = new PythonJob("job1", code, false);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
PythonVariables outputs = new PythonVariables();
|
||||||
|
outputs.addInt("c");
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(5L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addFloat("c");
|
||||||
|
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addNDArray("c");
|
||||||
|
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonJobReturnAllVariables()throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code = "c = a + b";
|
||||||
|
PythonJob job = new PythonJob("job1", code, false);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(5L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMultiplePythonJobsParallel()throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code1 = "c = a + b";
|
||||||
|
PythonJob job1 = new PythonJob("job1", code1, false);
|
||||||
|
|
||||||
|
String code2 = "c = a - b";
|
||||||
|
PythonJob job2 = new PythonJob("job2", code2, false);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
PythonVariables outputs = new PythonVariables();
|
||||||
|
outputs.addInt("c");
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(5L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(-1L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addFloat("c");
|
||||||
|
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(-1L, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addNDArray("c");
|
||||||
|
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).sub(1), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
@Test
|
||||||
|
public void testPythonJobSetupRun()throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code = "five=None\n" +
|
||||||
|
"def setup():\n" +
|
||||||
|
" global five\n"+
|
||||||
|
" five = 5\n\n" +
|
||||||
|
"def run(a, b):\n" +
|
||||||
|
" c = a + b + five\n"+
|
||||||
|
" return {'c':c}\n\n";
|
||||||
|
PythonJob job = new PythonJob("job1", code, true);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
PythonVariables outputs = new PythonVariables();
|
||||||
|
outputs.addInt("c");
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(10L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addFloat("c");
|
||||||
|
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addNDArray("c");
|
||||||
|
|
||||||
|
|
||||||
|
job.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
@Test
|
||||||
|
public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code = "five=None\n" +
|
||||||
|
"def setup():\n" +
|
||||||
|
" global five\n"+
|
||||||
|
" five = 5\n\n" +
|
||||||
|
"def run(a, b):\n" +
|
||||||
|
" c = a + b + five\n"+
|
||||||
|
" return {'c':c}\n\n";
|
||||||
|
PythonJob job = new PythonJob("job1", code, true);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(10L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = job.execAndReturnAllVariables(inputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMultiplePythonJobsSetupRunParallel()throws Exception{
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
|
String code1 = "five=None\n" +
|
||||||
|
"def setup():\n" +
|
||||||
|
" global five\n"+
|
||||||
|
" five = 5\n\n" +
|
||||||
|
"def run(a, b):\n" +
|
||||||
|
" c = a + b + five\n"+
|
||||||
|
" return {'c':c}\n\n";
|
||||||
|
PythonJob job1 = new PythonJob("job1", code1, true);
|
||||||
|
|
||||||
|
String code2 = "five=None\n" +
|
||||||
|
"def setup():\n" +
|
||||||
|
" global five\n"+
|
||||||
|
" five = 5\n\n" +
|
||||||
|
"def run(a, b):\n" +
|
||||||
|
" c = a + b - five\n"+
|
||||||
|
" return {'c':c}\n\n";
|
||||||
|
PythonJob job2 = new PythonJob("job2", code2, true);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addInt("a", 2);
|
||||||
|
inputs.addInt("b", 3);
|
||||||
|
|
||||||
|
PythonVariables outputs = new PythonVariables();
|
||||||
|
outputs.addInt("c");
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(10L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(0L, (long)outputs.getIntValue("c"));
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addFloat("a", 3.0);
|
||||||
|
inputs.addFloat("b", 4.0);
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addFloat("c");
|
||||||
|
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(2L, outputs.getFloatValue("c"), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
|
||||||
|
inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));
|
||||||
|
|
||||||
|
outputs = new PythonVariables();
|
||||||
|
outputs.addNDArray("c");
|
||||||
|
|
||||||
|
|
||||||
|
job1.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
|
||||||
|
|
||||||
|
job2.exec(inputs, outputs);
|
||||||
|
|
||||||
|
assertEquals(Nd4j.zeros(3, 2).add(4), outputs.getNDArrayValue("c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.var;
|
||||||
|
import org.json.JSONArray;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
public class TestPythonList {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonListFromIntArray() {
|
||||||
|
PythonObject pyList = new PythonObject(new Integer[]{1, 2, 3, 4, 5});
|
||||||
|
pyList.attr("append").call(6);
|
||||||
|
pyList.attr("append").call(7);
|
||||||
|
pyList.attr("append").call(8);
|
||||||
|
assertEquals(8, Python.len(pyList).toInt());
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
assertEquals(i + 1, pyList.get(i).toInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonListFromLongArray() {
|
||||||
|
PythonObject pyList = new PythonObject(new Long[]{1L, 2L, 3L, 4L, 5L});
|
||||||
|
pyList.attr("append").call(6);
|
||||||
|
pyList.attr("append").call(7);
|
||||||
|
pyList.attr("append").call(8);
|
||||||
|
assertEquals(8, Python.len(pyList).toInt());
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
assertEquals(i + 1, pyList.get(i).toInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonListFromDoubleArray() {
|
||||||
|
PythonObject pyList = new PythonObject(new Double[]{1., 2., 3., 4., 5.});
|
||||||
|
pyList.attr("append").call(6);
|
||||||
|
pyList.attr("append").call(7);
|
||||||
|
pyList.attr("append").call(8);
|
||||||
|
assertEquals(8, Python.len(pyList).toInt());
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
assertEquals(i + 1, pyList.get(i).toInt());
|
||||||
|
assertEquals((double) i + 1, pyList.get(i).toDouble(), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonListFromStringArray() {
|
||||||
|
PythonObject pyList = new PythonObject(new String[]{"abcd", "efg"});
|
||||||
|
pyList.attr("append").call("hijk");
|
||||||
|
pyList.attr("append").call("lmnop");
|
||||||
|
assertEquals("abcdefghijklmnop", new PythonObject("").attr("join").call(pyList).toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonListFromMixedArray()throws Exception {
|
||||||
|
Map<Object, Object> map = new HashMap<>();
|
||||||
|
map.put(1, "a");
|
||||||
|
map.put("a", Arrays.asList("a", "b", "c"));
|
||||||
|
map.put("arr", Nd4j.linspace(1, 4, 4));
|
||||||
|
Object[] objs = new Object[]{
|
||||||
|
1, 2, "a", 3f, 4L, 5.0, Arrays.asList(10,
|
||||||
|
20, "b", 30f, 40L, 50.0, map
|
||||||
|
|
||||||
|
), map
|
||||||
|
};
|
||||||
|
PythonObject pyList = new PythonObject(objs);
|
||||||
|
System.out.println(pyList.toString());
|
||||||
|
String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" +
|
||||||
|
", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," +
|
||||||
|
" 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" +
|
||||||
|
"'a', 'b', 'c']}], {'arr': array([1., 2., 3.," +
|
||||||
|
" 4.], dtype=float32), 1: 'a', 'a': ['a', 'b', 'c']}]";
|
||||||
|
assertEquals(expectedStr, pyList.toString());
|
||||||
|
List objs2 = pyList.toList();
|
||||||
|
PythonObject pyList2 = new PythonObject(objs2);
|
||||||
|
assertEquals(pyList.toString(), pyList2.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,27 +0,0 @@
|
||||||
package org.datavec.python;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
@javax.annotation.concurrent.NotThreadSafe
|
|
||||||
public class TestPythonSetupAndRun {
|
|
||||||
@Test
|
|
||||||
public void testPythonWithSetupAndRun() throws Exception{
|
|
||||||
String code = "def setup():" +
|
|
||||||
"global counter;counter=0\n" +
|
|
||||||
"def run(step):" +
|
|
||||||
"global counter;" +
|
|
||||||
"counter+=step;" +
|
|
||||||
"return {\"counter\":counter}";
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
|
||||||
pyInputs.addInt("step", 2);
|
|
||||||
PythonVariables pyOutputs = new PythonVariables();
|
|
||||||
pyOutputs.addInt("counter");
|
|
||||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
|
|
||||||
assertEquals((long)pyOutputs.getIntValue("counter"), 2L);
|
|
||||||
pyInputs.addInt("step", 3);
|
|
||||||
PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
|
|
||||||
assertEquals((long)pyOutputs.getIntValue("counter"), 5L);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -22,11 +22,14 @@
|
||||||
|
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNotNull;
|
import static junit.framework.TestCase.assertNotNull;
|
||||||
import static junit.framework.TestCase.assertNull;
|
import static junit.framework.TestCase.assertNull;
|
||||||
|
@ -36,59 +39,50 @@ import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class TestPythonVariables {
|
public class TestPythonVariables {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImportNumpy(){
|
public void testDataAssociations() throws PythonException{
|
||||||
Nd4j.scalar(1.0);
|
|
||||||
System.out.println(System.getProperty("org.bytedeco.openblas.load"));
|
|
||||||
PythonExecutioner.exec("import numpy as np");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDataAssociations() {
|
|
||||||
PythonVariables pythonVariables = new PythonVariables();
|
PythonVariables pythonVariables = new PythonVariables();
|
||||||
PythonVariables.Type[] types = {
|
PythonType[] types = {
|
||||||
PythonVariables.Type.INT,
|
PythonType.INT,
|
||||||
PythonVariables.Type.FLOAT,
|
PythonType.FLOAT,
|
||||||
PythonVariables.Type.STR,
|
PythonType.STR,
|
||||||
PythonVariables.Type.BOOL,
|
PythonType.BOOL,
|
||||||
PythonVariables.Type.DICT,
|
PythonType.DICT,
|
||||||
PythonVariables.Type.LIST,
|
PythonType.LIST,
|
||||||
PythonVariables.Type.LIST,
|
PythonType.LIST,
|
||||||
PythonVariables.Type.FILE,
|
PythonType.NDARRAY,
|
||||||
PythonVariables.Type.NDARRAY
|
PythonType.BYTES
|
||||||
};
|
};
|
||||||
|
|
||||||
NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
|
INDArray arr = Nd4j.scalar(1.0);
|
||||||
|
BytePointer bp = new BytePointer(arr.data().pointer());
|
||||||
Object[] values = {
|
Object[] values = {
|
||||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||||
new Object[]{1}, Arrays.asList(1),"type", npArr
|
new Object[]{1}, Arrays.asList(1), arr, bp
|
||||||
};
|
};
|
||||||
|
|
||||||
Object[] expectedValues = {
|
Object[] expectedValues = {
|
||||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||||
new Object[]{1}, new Object[]{1},"type", npArr
|
Arrays.asList(1), Arrays.asList(1), arr, bp
|
||||||
};
|
};
|
||||||
|
|
||||||
for(int i = 0; i < types.length; i++) {
|
for(int i = 0; i < types.length; i++) {
|
||||||
testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]);
|
testInsertGet(pythonVariables,types[i].getName().name() + i,values[i],types[i],expectedValues[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(types.length,pythonVariables.getVariables().length);
|
assertEquals(types.length,pythonVariables.getVariables().length);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) {
|
private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonType type,Object expectedValue) throws PythonException{
|
||||||
pythonVariables.add(key, type);
|
pythonVariables.add(key, type);
|
||||||
assertNull(pythonVariables.getValue(key));
|
assertNull(pythonVariables.getValue(key));
|
||||||
pythonVariables.setValue(key,value);
|
pythonVariables.setValue(key,value);
|
||||||
assertNotNull(pythonVariables.getValue(key));
|
assertNotNull(pythonVariables.getValue(key));
|
||||||
Object actualValue = pythonVariables.getValue(key);
|
Object actualValue = pythonVariables.getValue(key);
|
||||||
if (expectedValue instanceof Object[]){
|
if (expectedValue instanceof Object[]){
|
||||||
assertTrue(actualValue instanceof Object[]);
|
assertTrue(actualValue instanceof List);
|
||||||
Object[] actualArr = (Object[])actualValue;
|
Object[] actualArr = ((List)actualValue).toArray();
|
||||||
Object[] expectedArr = (Object[])expectedValue;
|
Object[] expectedArr = (Object[])expectedValue;
|
||||||
assertArrayEquals(expectedArr, actualArr);
|
assertArrayEquals(expectedArr, actualArr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class TestSerde {
|
||||||
String yaml = y.serialize(t);
|
String yaml = y.serialize(t);
|
||||||
String json = j.serialize(t);
|
String json = j.serialize(t);
|
||||||
|
|
||||||
Transform t2 = y.deserializeTransform(json);
|
Transform t2 = y.deserializeTransform(yaml);
|
||||||
Transform t3 = j.deserializeTransform(json);
|
Transform t3 = j.deserializeTransform(json);
|
||||||
assertEquals(t, t2);
|
assertEquals(t, t2);
|
||||||
assertEquals(t, t3);
|
assertEquals(t, t3);
|
||||||
|
|
Loading…
Reference in New Issue