Python: Use memoryview instead of bytearray (#225)

* memoryview

* cleanup
master
Fariz Rahman 2020-02-10 14:54:44 +04:00 committed by GitHub
parent 237c137166
commit c9ffb6cbec
4 changed files with 74 additions and 22 deletions

View File

@ -144,6 +144,14 @@ public class Python {
return attr("bytearray");
}
/**
 * Calls the Python {@code memoryview} attribute with the given object.
 *
 * @param pythonObject the object passed as the single call argument
 * @return the result of the call, as a {@link PythonObject}
 */
public static PythonObject memoryview(PythonObject pythonObject) {
    PythonObject memoryviewCallable = attr("memoryview");
    return memoryviewCallable.call(pythonObject);
}
/**
 * Looks up the {@code memoryview} attribute itself without calling it.
 *
 * @return the object returned by {@code attr("memoryview")}; used elsewhere
 *         in this change as the second argument of {@code Python.isinstance}
 */
public static PythonObject memoryviewType() {
    PythonObject memoryviewAttr = attr("memoryview");
    return memoryviewAttr;
}
/**
 * Calls the Python {@code bytes} attribute with the given object.
 *
 * @param pythonObject the object passed as the single call argument
 * @return the result of the call, as a {@link PythonObject}
 */
public static PythonObject bytes(PythonObject pythonObject) {
    PythonObject bytesCallable = attr("bytes");
    return bytesCallable.call(pythonObject);
}
@ -250,9 +258,6 @@ public class Python {
/**
 * Executes the given Python code by delegating to
 * {@code PythonExecutioner.exec(String)}.
 *
 * @param code the Python source to execute
 * @throws PythonException propagated from the executioner
 */
public static void exec(String code) throws PythonException {
    PythonExecutioner.exec(code);
}
/**
 * Executes the given Python code by delegating to
 * {@code PythonExecutioner.exec(String, PythonVariables)}.
 *
 * NOTE(review): the parameter is named {@code inputs}, but the two-argument
 * {@code PythonExecutioner.exec} shown in this change runs the code and then
 * reads variables back into its second argument (it names it
 * {@code outputVariables}) - confirm the intended direction.
 *
 * @param code   the Python source to execute
 * @param inputs the variables object forwarded unchanged to the executioner
 * @throws PythonException propagated from the executioner
 */
public static void exec(String code, PythonVariables inputs) throws PythonException {
    PythonExecutioner.exec(code, inputs);
}
/**
 * Executes the given Python code by delegating to
 * {@code PythonExecutioner.exec(String, PythonVariables, PythonVariables)}.
 *
 * @param code    the Python source to execute
 * @param inputs  forwarded unchanged as the executioner's second argument
 * @param outputs forwarded unchanged as the executioner's third argument
 * @throws PythonException propagated from the executioner
 */
public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException {
    PythonExecutioner.exec(code, inputs, outputs);
}

View File

@ -20,23 +20,16 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.PyThreadState;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.PyThreadState_Get;
import static org.datavec.python.Python.*;
/**
@ -105,6 +98,7 @@ public class PythonExecutioner {
init();
}
private static synchronized void init() {
if (init.get()) {
return;
@ -204,6 +198,9 @@ public class PythonExecutioner {
}
public static void getVariables(PythonVariables pyVars) throws PythonException {
if (pyVars == null){
return;
}
for (String varName : pyVars.getVariables()) {
pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName)));
}
@ -240,12 +237,6 @@ public class PythonExecutioner {
throwIfExecutionFailed();
}
/**
 * Runs the wrapped form of {@code code}, raises if execution failed, and
 * then reads variables back into {@code outputVariables}.
 *
 * @param code            the Python source; passed through
 *                        {@code getWrappedCode} before execution
 * @param outputVariables the variables populated via {@code getVariables}
 *                        after a successful run
 * @throws PythonException if {@code throwIfExecutionFailed} or
 *                         {@code getVariables} raises it
 */
public static void exec(String code, PythonVariables outputVariables) throws PythonException {
    // Order matters: execute first, surface any failure, only then read results.
    simpleExec(getWrappedCode(code));
    throwIfExecutionFailed();
    getVariables(outputVariables);
}
public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException {
setVariables(inputVariables);
simpleExec(getWrappedCode(code));
@ -354,7 +345,6 @@ public class PythonExecutioner {
log.info("Setting python path " + path);
StringBuffer sb = new StringBuffer();
File[] packages = numpy.cachePackages();
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) {
case BEFORE:
@ -395,4 +385,5 @@ public class PythonExecutioner {
throw new IllegalStateException("Unable to reset python path. Already initialized.");
}
}
}

View File

@ -69,7 +69,11 @@ public class PythonObject {
}
/**
 * Creates a Python object from the memory referenced by a {@code BytePointer}.
 *
 * The later statements describe the pointer's memory as a one-dimensional
 * BYTE {@code NumpyArray} (address = {@code bp.address()}, length =
 * {@code bp.capacity()}, stride 1), wrap that in a {@code PythonObject} and
 * pass it to {@code Python.memoryview}; the resulting native object is stored.
 *
 * @param bp the pointer whose address and capacity define the wrapped region
 */
public PythonObject(BytePointer bp){
// NOTE(review): this assignment is overwritten by the memoryview assignment
// below; it looks like the pre-change line of the diff hunk - confirm that
// only one of the two assignments is meant to remain.
nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
long address = bp.address();
long size = bp.capacity();
// Shape {size} with stride {1} and dtype BYTE: one element per byte of capacity.
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
}
public PythonObject(NumpyArray npArray) {
@ -343,13 +347,28 @@ public class PythonObject {
dtype = DataType.DOUBLE;
} else if (dtypeName.equals("float32")) {
dtype = DataType.FLOAT;
} else if (dtypeName.equals("int8")){
dtype = DataType.INT8;
}else if (dtypeName.equals("int16")) {
dtype = DataType.SHORT;
} else if (dtypeName.equals("int32")) {
dtype = DataType.INT;
} else if (dtypeName.equals("int64")) {
dtype = DataType.LONG;
} else {
}
else if (dtypeName.equals("uint8")){
dtype = DataType.UINT8;
}
else if (dtypeName.equals("uint16")){
dtype = DataType.UINT16;
}
else if (dtypeName.equals("uint32")){
dtype = DataType.UINT32;
}
else if (dtypeName.equals("uint64")){
dtype = DataType.UINT64;
}
else {
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
return new NumpyArray(address, jshape, jstrides, dtype);
@ -518,6 +537,22 @@ public class PythonObject {
else if (Python.isinstance(this, Python.bytearrayType())){
return PyByteArray_AsString(nativePythonObject);
}
else if (Python.isinstance(this, Python.memoryviewType())){
// PyObject np = PyImport_ImportModule("numpy");
// PyObject array = PyObject_GetAttrString(np, "asarray");
// PyObject npArr = PyObject_CallObject(array, nativePythonObject); // Doesn't work
// Invoke interpreter:
String tempContext = "temp" + UUID.randomUUID().toString().replace('-', '_');
String originalContext = Python.getCurrentContext();
Python.setContext(tempContext);
PythonExecutioner.setVariable("memview", this);
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
Python.setContext(originalContext);
Python.deleteContext(tempContext);
return ret;
}
else{
PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
@ -542,7 +577,7 @@ public class PythonObject {
return new BytePointer(ptr);
}
else{
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
}
}

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@ -237,12 +238,13 @@ public class TestPythonExecutioner {
PythonVariables pyOutputs= new PythonVariables();
pyOutputs.addStr("out");
String code = "out = buff.decode()";
String code = "out = bytes(buff).decode()";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
}
@Test
public void testByteBufferOutputNoCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
@ -262,6 +264,24 @@ public class TestPythonExecutioner {
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
}
/**
 * Checks that Python code which mutates a bytes input in place is visible
 * both through the {@code BytePointer} held by the inputs and through the
 * {@code INDArray} the pointer was created from.
 */
@Test
public void testByteBufferInplace() throws Exception{
    INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
    buff.putScalar(0, 97); // a
    buff.putScalar(1, 98); // b
    buff.putScalar(2, 99); // c
    PythonVariables pyInputs = new PythonVariables();
    pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
    // 'a' + 2 == 'c' and 'c' - 2 == 'a', so "abc" becomes "cba".
    String code = "buff[0]+=2\nbuff[2]-=2";
    // No outputs are requested; null is passed for the output variables.
    Python.exec(code, pyInputs, null);
    Assert.assertEquals("cba", pyInputs.getBytesValue("buff").getString());
    INDArray expected = buff.dup();
    expected.putScalar(0, 99);
    expected.putScalar(2, 97);
    // JUnit's contract is assertEquals(expected, actual); keeping that order
    // makes a failure message label the two values correctly.
    Assert.assertEquals(expected, buff);
}
@Test
public void testByteBufferOutputWithCopy() throws Exception{
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
@ -302,4 +322,5 @@ public class TestPythonExecutioner {
Python.setMainContext();
}
}