parent
237c137166
commit
c9ffb6cbec
|
@ -144,6 +144,14 @@ public class Python {
|
||||||
return attr("bytearray");
|
return attr("bytearray");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static PythonObject memoryview(PythonObject pythonObject) {
|
||||||
|
return attr("memoryview").call(pythonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PythonObject memoryviewType() {
|
||||||
|
return attr("memoryview");
|
||||||
|
}
|
||||||
|
|
||||||
public static PythonObject bytes(PythonObject pythonObject) {
|
public static PythonObject bytes(PythonObject pythonObject) {
|
||||||
return attr("bytes").call(pythonObject);
|
return attr("bytes").call(pythonObject);
|
||||||
}
|
}
|
||||||
|
@ -250,9 +258,6 @@ public class Python {
|
||||||
public static void exec(String code)throws PythonException{
|
public static void exec(String code)throws PythonException{
|
||||||
PythonExecutioner.exec(code);
|
PythonExecutioner.exec(code);
|
||||||
}
|
}
|
||||||
public static void exec(String code, PythonVariables inputs) throws PythonException{
|
|
||||||
PythonExecutioner.exec(code, inputs);
|
|
||||||
}
|
|
||||||
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
|
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
|
||||||
PythonExecutioner.exec(code, inputs, outputs);
|
PythonExecutioner.exec(code, inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,23 +20,16 @@ package org.datavec.python;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.bytedeco.cpython.PyThreadState;
|
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
|
||||||
import org.bytedeco.numpy.global.numpy;
|
import org.bytedeco.numpy.global.numpy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.nio.charset.Charset;
|
import java.nio.charset.Charset;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
import static org.bytedeco.cpython.global.python.*;
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
import static org.bytedeco.cpython.global.python.PyThreadState_Get;
|
|
||||||
import static org.datavec.python.Python.*;
|
import static org.datavec.python.Python.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -105,6 +98,7 @@ public class PythonExecutioner {
|
||||||
init();
|
init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static synchronized void init() {
|
private static synchronized void init() {
|
||||||
if (init.get()) {
|
if (init.get()) {
|
||||||
return;
|
return;
|
||||||
|
@ -204,6 +198,9 @@ public class PythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void getVariables(PythonVariables pyVars) throws PythonException {
|
public static void getVariables(PythonVariables pyVars) throws PythonException {
|
||||||
|
if (pyVars == null){
|
||||||
|
return;
|
||||||
|
}
|
||||||
for (String varName : pyVars.getVariables()) {
|
for (String varName : pyVars.getVariables()) {
|
||||||
pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName)));
|
pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName)));
|
||||||
}
|
}
|
||||||
|
@ -240,12 +237,6 @@ public class PythonExecutioner {
|
||||||
throwIfExecutionFailed();
|
throwIfExecutionFailed();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void exec(String code, PythonVariables outputVariables)throws PythonException {
|
|
||||||
simpleExec(getWrappedCode(code));
|
|
||||||
throwIfExecutionFailed();
|
|
||||||
getVariables(outputVariables);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException {
|
public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException {
|
||||||
setVariables(inputVariables);
|
setVariables(inputVariables);
|
||||||
simpleExec(getWrappedCode(code));
|
simpleExec(getWrappedCode(code));
|
||||||
|
@ -354,7 +345,6 @@ public class PythonExecutioner {
|
||||||
log.info("Setting python path " + path);
|
log.info("Setting python path " + path);
|
||||||
StringBuffer sb = new StringBuffer();
|
StringBuffer sb = new StringBuffer();
|
||||||
File[] packages = numpy.cachePackages();
|
File[] packages = numpy.cachePackages();
|
||||||
|
|
||||||
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
|
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
|
||||||
switch (pathAppendValue) {
|
switch (pathAppendValue) {
|
||||||
case BEFORE:
|
case BEFORE:
|
||||||
|
@ -395,4 +385,5 @@ public class PythonExecutioner {
|
||||||
throw new IllegalStateException("Unable to reset python path. Already initialized.");
|
throw new IllegalStateException("Unable to reset python path. Already initialized.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,7 +69,11 @@ public class PythonObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
public PythonObject(BytePointer bp){
|
public PythonObject(BytePointer bp){
|
||||||
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
|
|
||||||
|
long address = bp.address();
|
||||||
|
long size = bp.capacity();
|
||||||
|
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
|
||||||
|
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
|
||||||
}
|
}
|
||||||
|
|
||||||
public PythonObject(NumpyArray npArray) {
|
public PythonObject(NumpyArray npArray) {
|
||||||
|
@ -343,13 +347,28 @@ public class PythonObject {
|
||||||
dtype = DataType.DOUBLE;
|
dtype = DataType.DOUBLE;
|
||||||
} else if (dtypeName.equals("float32")) {
|
} else if (dtypeName.equals("float32")) {
|
||||||
dtype = DataType.FLOAT;
|
dtype = DataType.FLOAT;
|
||||||
} else if (dtypeName.equals("int16")) {
|
} else if (dtypeName.equals("int8")){
|
||||||
|
dtype = DataType.INT8;
|
||||||
|
}else if (dtypeName.equals("int16")) {
|
||||||
dtype = DataType.SHORT;
|
dtype = DataType.SHORT;
|
||||||
} else if (dtypeName.equals("int32")) {
|
} else if (dtypeName.equals("int32")) {
|
||||||
dtype = DataType.INT;
|
dtype = DataType.INT;
|
||||||
} else if (dtypeName.equals("int64")) {
|
} else if (dtypeName.equals("int64")) {
|
||||||
dtype = DataType.LONG;
|
dtype = DataType.LONG;
|
||||||
} else {
|
}
|
||||||
|
else if (dtypeName.equals("uint8")){
|
||||||
|
dtype = DataType.UINT8;
|
||||||
|
}
|
||||||
|
else if (dtypeName.equals("uint16")){
|
||||||
|
dtype = DataType.UINT16;
|
||||||
|
}
|
||||||
|
else if (dtypeName.equals("uint32")){
|
||||||
|
dtype = DataType.UINT32;
|
||||||
|
}
|
||||||
|
else if (dtypeName.equals("uint64")){
|
||||||
|
dtype = DataType.UINT64;
|
||||||
|
}
|
||||||
|
else {
|
||||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||||
}
|
}
|
||||||
return new NumpyArray(address, jshape, jstrides, dtype);
|
return new NumpyArray(address, jshape, jstrides, dtype);
|
||||||
|
@ -518,6 +537,22 @@ public class PythonObject {
|
||||||
else if (Python.isinstance(this, Python.bytearrayType())){
|
else if (Python.isinstance(this, Python.bytearrayType())){
|
||||||
return PyByteArray_AsString(nativePythonObject);
|
return PyByteArray_AsString(nativePythonObject);
|
||||||
}
|
}
|
||||||
|
else if (Python.isinstance(this, Python.memoryviewType())){
|
||||||
|
|
||||||
|
// PyObject np = PyImport_ImportModule("numpy");
|
||||||
|
// PyObject array = PyObject_GetAttrString(np, "asarray");
|
||||||
|
// PyObject npArr = PyObject_CallObject(array, nativePythonObject); // Doesn't work
|
||||||
|
// Invoke interpreter:
|
||||||
|
String tempContext = "temp" + UUID.randomUUID().toString().replace('-', '_');
|
||||||
|
String originalContext = Python.getCurrentContext();
|
||||||
|
Python.setContext(tempContext);
|
||||||
|
PythonExecutioner.setVariable("memview", this);
|
||||||
|
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
|
||||||
|
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
|
||||||
|
Python.setContext(originalContext);
|
||||||
|
Python.deleteContext(tempContext);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
else{
|
else{
|
||||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||||
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||||
|
@ -542,7 +577,7 @@ public class PythonObject {
|
||||||
return new BytePointer(ptr);
|
return new BytePointer(ptr);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
|
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
@ -237,12 +238,13 @@ public class TestPythonExecutioner {
|
||||||
PythonVariables pyOutputs= new PythonVariables();
|
PythonVariables pyOutputs= new PythonVariables();
|
||||||
pyOutputs.addStr("out");
|
pyOutputs.addStr("out");
|
||||||
|
|
||||||
String code = "out = buff.decode()";
|
String code = "out = bytes(buff).decode()";
|
||||||
Python.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
|
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testByteBufferOutputNoCopy() throws Exception{
|
public void testByteBufferOutputNoCopy() throws Exception{
|
||||||
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
@ -262,6 +264,24 @@ public class TestPythonExecutioner {
|
||||||
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferInplace() throws Exception{
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
String code = "buff[0]+=2\nbuff[2]-=2";
|
||||||
|
Python.exec(code, pyInputs, null);
|
||||||
|
Assert.assertEquals("cba", pyInputs.getBytesValue("buff").getString());
|
||||||
|
INDArray expected = buff.dup();
|
||||||
|
expected.putScalar(0, 99);
|
||||||
|
expected.putScalar(2, 97);
|
||||||
|
Assert.assertEquals(buff, expected);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testByteBufferOutputWithCopy() throws Exception{
|
public void testByteBufferOutputWithCopy() throws Exception{
|
||||||
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
@ -302,4 +322,5 @@ public class TestPythonExecutioner {
|
||||||
Python.setMainContext();
|
Python.setMainContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue