Python: Use memoryview instead of bytearray (#225)

* memoryview

* cleanup
master
Fariz Rahman 2020-02-10 14:54:44 +04:00 committed by GitHub
parent 237c137166
commit c9ffb6cbec
4 changed files with 74 additions and 22 deletions

View File

@ -144,6 +144,14 @@ public class Python {
return attr("bytearray"); return attr("bytearray");
} }
public static PythonObject memoryview(PythonObject pythonObject) {
return attr("memoryview").call(pythonObject);
}
public static PythonObject memoryviewType() {
return attr("memoryview");
}
public static PythonObject bytes(PythonObject pythonObject) { public static PythonObject bytes(PythonObject pythonObject) {
return attr("bytes").call(pythonObject); return attr("bytes").call(pythonObject);
} }
@ -250,9 +258,6 @@ public class Python {
public static void exec(String code)throws PythonException{ public static void exec(String code)throws PythonException{
PythonExecutioner.exec(code); PythonExecutioner.exec(code);
} }
public static void exec(String code, PythonVariables inputs) throws PythonException{
PythonExecutioner.exec(code, inputs);
}
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{ public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{
PythonExecutioner.exec(code, inputs, outputs); PythonExecutioner.exec(code, inputs, outputs);
} }

View File

@ -20,23 +20,16 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.PyThreadState;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.numpy.global.numpy; import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static org.bytedeco.cpython.global.python.*; import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.PyThreadState_Get;
import static org.datavec.python.Python.*; import static org.datavec.python.Python.*;
/** /**
@ -105,6 +98,7 @@ public class PythonExecutioner {
init(); init();
} }
private static synchronized void init() { private static synchronized void init() {
if (init.get()) { if (init.get()) {
return; return;
@ -204,6 +198,9 @@ public class PythonExecutioner {
} }
public static void getVariables(PythonVariables pyVars) throws PythonException { public static void getVariables(PythonVariables pyVars) throws PythonException {
if (pyVars == null){
return;
}
for (String varName : pyVars.getVariables()) { for (String varName : pyVars.getVariables()) {
pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName))); pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName)));
} }
@ -240,12 +237,6 @@ public class PythonExecutioner {
throwIfExecutionFailed(); throwIfExecutionFailed();
} }
public static void exec(String code, PythonVariables outputVariables)throws PythonException {
simpleExec(getWrappedCode(code));
throwIfExecutionFailed();
getVariables(outputVariables);
}
public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException { public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException {
setVariables(inputVariables); setVariables(inputVariables);
simpleExec(getWrappedCode(code)); simpleExec(getWrappedCode(code));
@ -354,7 +345,6 @@ public class PythonExecutioner {
log.info("Setting python path " + path); log.info("Setting python path " + path);
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
File[] packages = numpy.cachePackages(); File[] packages = numpy.cachePackages();
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) { switch (pathAppendValue) {
case BEFORE: case BEFORE:
@ -395,4 +385,5 @@ public class PythonExecutioner {
throw new IllegalStateException("Unable to reset python path. Already initialized."); throw new IllegalStateException("Unable to reset python path. Already initialized.");
} }
} }
} }

View File

@ -69,7 +69,11 @@ public class PythonObject {
} }
public PythonObject(BytePointer bp){ public PythonObject(BytePointer bp){
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
long address = bp.address();
long size = bp.capacity();
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
} }
public PythonObject(NumpyArray npArray) { public PythonObject(NumpyArray npArray) {
@ -343,13 +347,28 @@ public class PythonObject {
dtype = DataType.DOUBLE; dtype = DataType.DOUBLE;
} else if (dtypeName.equals("float32")) { } else if (dtypeName.equals("float32")) {
dtype = DataType.FLOAT; dtype = DataType.FLOAT;
} else if (dtypeName.equals("int16")) { } else if (dtypeName.equals("int8")){
dtype = DataType.INT8;
}else if (dtypeName.equals("int16")) {
dtype = DataType.SHORT; dtype = DataType.SHORT;
} else if (dtypeName.equals("int32")) { } else if (dtypeName.equals("int32")) {
dtype = DataType.INT; dtype = DataType.INT;
} else if (dtypeName.equals("int64")) { } else if (dtypeName.equals("int64")) {
dtype = DataType.LONG; dtype = DataType.LONG;
} else { }
else if (dtypeName.equals("uint8")){
dtype = DataType.UINT8;
}
else if (dtypeName.equals("uint16")){
dtype = DataType.UINT16;
}
else if (dtypeName.equals("uint32")){
dtype = DataType.UINT32;
}
else if (dtypeName.equals("uint64")){
dtype = DataType.UINT64;
}
else {
throw new RuntimeException("Unsupported array type " + dtypeName + "."); throw new RuntimeException("Unsupported array type " + dtypeName + ".");
} }
return new NumpyArray(address, jshape, jstrides, dtype); return new NumpyArray(address, jshape, jstrides, dtype);
@ -518,6 +537,22 @@ public class PythonObject {
else if (Python.isinstance(this, Python.bytearrayType())){ else if (Python.isinstance(this, Python.bytearrayType())){
return PyByteArray_AsString(nativePythonObject); return PyByteArray_AsString(nativePythonObject);
} }
else if (Python.isinstance(this, Python.memoryviewType())){
// PyObject np = PyImport_ImportModule("numpy");
// PyObject array = PyObject_GetAttrString(np, "asarray");
// PyObject npArr = PyObject_CallObject(array, nativePythonObject); // Doesn't work
// Invoke interpreter:
String tempContext = "temp" + UUID.randomUUID().toString().replace('-', '_');
String originalContext = Python.getCurrentContext();
Python.setContext(tempContext);
PythonExecutioner.setVariable("memview", this);
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
Python.setContext(originalContext);
Python.deleteContext(tempContext);
return ret;
}
else{ else{
PyObject ctypes = PyImport_ImportModule("ctypes"); PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array"); PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
@ -542,7 +577,7 @@ public class PythonObject {
return new BytePointer(ptr); return new BytePointer(ptr);
} }
else{ else{
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString()); throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
} }
} }

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -237,12 +238,13 @@ public class TestPythonExecutioner {
PythonVariables pyOutputs= new PythonVariables(); PythonVariables pyOutputs= new PythonVariables();
pyOutputs.addStr("out"); pyOutputs.addStr("out");
String code = "out = buff.decode()"; String code = "out = bytes(buff).decode()";
Python.exec(code, pyInputs, pyOutputs); Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("abc", pyOutputs.getStrValue("out")); Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
} }
@Test @Test
public void testByteBufferOutputNoCopy() throws Exception{ public void testByteBufferOutputNoCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
@ -262,6 +264,24 @@ public class TestPythonExecutioner {
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString()); Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
} }
@Test
public void testByteBufferInplace() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
String code = "buff[0]+=2\nbuff[2]-=2";
Python.exec(code, pyInputs, null);
Assert.assertEquals("cba", pyInputs.getBytesValue("buff").getString());
INDArray expected = buff.dup();
expected.putScalar(0, 99);
expected.putScalar(2, 97);
Assert.assertEquals(buff, expected);
}
@Test @Test
public void testByteBufferOutputWithCopy() throws Exception{ public void testByteBufferOutputWithCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
@ -302,4 +322,5 @@ public class TestPythonExecutioner {
Python.setMainContext(); Python.setMainContext();
} }
} }