parent
237c137166
commit
c9ffb6cbec
|
@ -144,6 +144,14 @@ public class Python {
|
|||
return attr("bytearray");
|
||||
}
|
||||
|
||||
public static PythonObject memoryview(PythonObject pythonObject) {
|
||||
return attr("memoryview").call(pythonObject);
|
||||
}
|
||||
|
||||
public static PythonObject memoryviewType() {
|
||||
return attr("memoryview");
|
||||
}
|
||||
|
||||
public static PythonObject bytes(PythonObject pythonObject) {
|
||||
return attr("bytes").call(pythonObject);
|
||||
}
|
||||
|
@ -250,9 +258,6 @@ public class Python {
|
|||
public static void exec(String code)throws PythonException{
|
||||
PythonExecutioner.exec(code);
|
||||
}
|
||||
public static void exec(String code, PythonVariables inputs) throws PythonException{
|
||||
PythonExecutioner.exec(code, inputs);
|
||||
}
|
||||
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
|
||||
PythonExecutioner.exec(code, inputs, outputs);
|
||||
}
|
||||
|
|
|
@ -20,23 +20,16 @@ package org.datavec.python;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.bytedeco.cpython.PyThreadState;
|
||||
import org.bytedeco.javacpp.BytePointer;
|
||||
import org.bytedeco.numpy.global.numpy;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import static org.bytedeco.cpython.global.python.*;
|
||||
import static org.bytedeco.cpython.global.python.PyThreadState_Get;
|
||||
import static org.datavec.python.Python.*;
|
||||
|
||||
/**
|
||||
|
@ -105,6 +98,7 @@ public class PythonExecutioner {
|
|||
init();
|
||||
}
|
||||
|
||||
|
||||
private static synchronized void init() {
|
||||
if (init.get()) {
|
||||
return;
|
||||
|
@ -204,6 +198,9 @@ public class PythonExecutioner {
|
|||
}
|
||||
|
||||
public static void getVariables(PythonVariables pyVars) throws PythonException {
|
||||
if (pyVars == null){
|
||||
return;
|
||||
}
|
||||
for (String varName : pyVars.getVariables()) {
|
||||
pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName)));
|
||||
}
|
||||
|
@ -240,12 +237,6 @@ public class PythonExecutioner {
|
|||
throwIfExecutionFailed();
|
||||
}
|
||||
|
||||
public static void exec(String code, PythonVariables outputVariables)throws PythonException {
|
||||
simpleExec(getWrappedCode(code));
|
||||
throwIfExecutionFailed();
|
||||
getVariables(outputVariables);
|
||||
}
|
||||
|
||||
public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException {
|
||||
setVariables(inputVariables);
|
||||
simpleExec(getWrappedCode(code));
|
||||
|
@ -354,7 +345,6 @@ public class PythonExecutioner {
|
|||
log.info("Setting python path " + path);
|
||||
StringBuffer sb = new StringBuffer();
|
||||
File[] packages = numpy.cachePackages();
|
||||
|
||||
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
|
||||
switch (pathAppendValue) {
|
||||
case BEFORE:
|
||||
|
@ -395,4 +385,5 @@ public class PythonExecutioner {
|
|||
throw new IllegalStateException("Unable to reset python path. Already initialized.");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -69,7 +69,11 @@ public class PythonObject {
|
|||
}
|
||||
|
||||
public PythonObject(BytePointer bp){
|
||||
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
|
||||
|
||||
long address = bp.address();
|
||||
long size = bp.capacity();
|
||||
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
|
||||
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
|
||||
}
|
||||
|
||||
public PythonObject(NumpyArray npArray) {
|
||||
|
@ -343,13 +347,28 @@ public class PythonObject {
|
|||
dtype = DataType.DOUBLE;
|
||||
} else if (dtypeName.equals("float32")) {
|
||||
dtype = DataType.FLOAT;
|
||||
} else if (dtypeName.equals("int16")) {
|
||||
} else if (dtypeName.equals("int8")){
|
||||
dtype = DataType.INT8;
|
||||
}else if (dtypeName.equals("int16")) {
|
||||
dtype = DataType.SHORT;
|
||||
} else if (dtypeName.equals("int32")) {
|
||||
dtype = DataType.INT;
|
||||
} else if (dtypeName.equals("int64")) {
|
||||
dtype = DataType.LONG;
|
||||
} else {
|
||||
}
|
||||
else if (dtypeName.equals("uint8")){
|
||||
dtype = DataType.UINT8;
|
||||
}
|
||||
else if (dtypeName.equals("uint16")){
|
||||
dtype = DataType.UINT16;
|
||||
}
|
||||
else if (dtypeName.equals("uint32")){
|
||||
dtype = DataType.UINT32;
|
||||
}
|
||||
else if (dtypeName.equals("uint64")){
|
||||
dtype = DataType.UINT64;
|
||||
}
|
||||
else {
|
||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||
}
|
||||
return new NumpyArray(address, jshape, jstrides, dtype);
|
||||
|
@ -518,6 +537,22 @@ public class PythonObject {
|
|||
else if (Python.isinstance(this, Python.bytearrayType())){
|
||||
return PyByteArray_AsString(nativePythonObject);
|
||||
}
|
||||
else if (Python.isinstance(this, Python.memoryviewType())){
|
||||
|
||||
// PyObject np = PyImport_ImportModule("numpy");
|
||||
// PyObject array = PyObject_GetAttrString(np, "asarray");
|
||||
// PyObject npArr = PyObject_CallObject(array, nativePythonObject); // Doesn't work
|
||||
// Invoke interpreter:
|
||||
String tempContext = "temp" + UUID.randomUUID().toString().replace('-', '_');
|
||||
String originalContext = Python.getCurrentContext();
|
||||
Python.setContext(tempContext);
|
||||
PythonExecutioner.setVariable("memview", this);
|
||||
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
|
||||
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
|
||||
Python.setContext(originalContext);
|
||||
Python.deleteContext(tempContext);
|
||||
return ret;
|
||||
}
|
||||
else{
|
||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||
|
@ -542,7 +577,7 @@ public class PythonObject {
|
|||
return new BytePointer(ptr);
|
||||
}
|
||||
else{
|
||||
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
|
||||
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
|
@ -237,12 +238,13 @@ public class TestPythonExecutioner {
|
|||
PythonVariables pyOutputs= new PythonVariables();
|
||||
pyOutputs.addStr("out");
|
||||
|
||||
String code = "out = buff.decode()";
|
||||
String code = "out = bytes(buff).decode()";
|
||||
Python.exec(code, pyInputs, pyOutputs);
|
||||
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testByteBufferOutputNoCopy() throws Exception{
|
||||
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||
|
@ -262,6 +264,24 @@ public class TestPythonExecutioner {
|
|||
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteBufferInplace() throws Exception{
|
||||
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||
buff.putScalar(0, 97); // a
|
||||
buff.putScalar(1, 98); // b
|
||||
buff.putScalar(2, 99); // c
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||
String code = "buff[0]+=2\nbuff[2]-=2";
|
||||
Python.exec(code, pyInputs, null);
|
||||
Assert.assertEquals("cba", pyInputs.getBytesValue("buff").getString());
|
||||
INDArray expected = buff.dup();
|
||||
expected.putScalar(0, 99);
|
||||
expected.putScalar(2, 97);
|
||||
Assert.assertEquals(buff, expected);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteBufferOutputWithCopy() throws Exception{
|
||||
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||
|
@ -302,4 +322,5 @@ public class TestPythonExecutioner {
|
|||
Python.setMainContext();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue