Merge pull request #8790 from KonduitAI/master

Additional update
master
Alex Black 2020-03-19 00:51:22 +11:00 committed by GitHub
commit 7c05928185
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 381 additions and 158 deletions

View File

@ -1,29 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform.analysis.counter;
import java.util.function.BiFunction;
/**
* Created by Alex on 5/03/2016.
*/
public class StringAnalysisMergeFunction
implements BiFunction<StringAnalysisCounter, StringAnalysisCounter, StringAnalysisCounter> {
public StringAnalysisCounter apply(StringAnalysisCounter v1, StringAnalysisCounter v2) {
return v1.merge(v2);
}
}

View File

@ -19,6 +19,7 @@ package org.datavec.python;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.ArrayUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
@ -29,6 +30,10 @@ import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.linalg.api.buffer.DataType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
@ -42,6 +47,7 @@ import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
public class NumpyArray {
private static NativeOps nativeOps;
private static Map<String, INDArray> arrayCache; // Avoids re-allocation of device buffer
private long address;
private long[] shape;
private long[] strides;
@ -52,6 +58,7 @@ public class NumpyArray {
//initialize
Nd4j.scalar(1.0);
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
arrayCache = new HashMap<>();
}
@Builder
@ -84,24 +91,42 @@ public class NumpyArray {
private void setND4JArray() {
long size = 1;
for (long d : shape) {
size *= d;
}
Pointer ptr = nativeOps.pointerForAddress(address);
ptr = ptr.limit(size);
ptr = ptr.capacity(size);
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize;
}
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
String cacheKey = address + "_" + size + "_" + dtype + "_" + ArrayUtils.toString(strides);
nd4jArray = arrayCache.get(cacheKey);
if (nd4jArray == null) {
Pointer ptr = nativeOps.pointerForAddress(address);
ptr = ptr.limit(size);
ptr = ptr.capacity(size);
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize;
}
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
arrayCache.put(cacheKey, nd4jArray);
}
else{
if (!Arrays.equals(nd4jArray.shape(), shape)){
nd4jArray = nd4jArray.reshape(shape);
}
}
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
}
public INDArray getNd4jArray(){
Nd4j.getAffinityManager().tagLocation(nd4jArray, AffinityManager.Location.HOST);
return nd4jArray;
}
public NumpyArray(INDArray nd4jArray) {
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
DataBuffer buff = nd4jArray.data();
@ -115,6 +140,8 @@ public class NumpyArray {
}
dtype = nd4jArray.dataType();
this.nd4jArray = nd4jArray;
String cacheKey = address + "_" + nd4jArray.length() + "_" + dtype + "_" + ArrayUtils.toString(strides);
arrayCache.put(cacheKey, nd4jArray);
}
}

View File

@ -21,6 +21,7 @@ package org.datavec.python;
import org.bytedeco.cpython.PyObject;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.PyArray_EnsureArray;
/**
* Swift like python wrapper for Java
@ -232,6 +233,10 @@ public class Python {
}
public static PythonObject ndarray(PythonObject pythonObject){
return new PythonObject(PyArray_EnsureArray(pythonObject.getNativePythonObject()));
}
public static boolean callable(PythonObject pythonObject) {
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
}

View File

@ -21,6 +21,9 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;

View File

@ -18,11 +18,15 @@
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.nativeblas.NativeOpsHolder;
@ -30,7 +34,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.PyObject_SetItem;
import static org.bytedeco.numpy.global.numpy.*;
/**
* Swift like python wrapper for J
@ -38,6 +42,7 @@ import static org.bytedeco.cpython.global.python.PyObject_SetItem;
* @author Fariz Rahman
*/
@Slf4j
public class PythonObject {
private PyObject nativePythonObject;
@ -77,78 +82,69 @@ public class PythonObject {
}
public PythonObject(NumpyArray npArray) {
PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject np = PyImport_ImportModule("numpy");
PyObject ctype;
switch (npArray.getDtype()) {
int numpyType;
INDArray indArray = npArray.getNd4jArray();
DataType dataType = indArray.dataType();
switch (dataType) {
case DOUBLE:
ctype = PyObject_GetAttrString(ctypes, "c_double");
numpyType = NPY_DOUBLE;
break;
case FLOAT:
ctype = PyObject_GetAttrString(ctypes, "c_float");
break;
case LONG:
ctype = PyObject_GetAttrString(ctypes, "c_int64");
break;
case INT:
ctype = PyObject_GetAttrString(ctypes, "c_int32");
case BFLOAT16:
numpyType = NPY_FLOAT;
break;
case SHORT:
ctype = PyObject_GetAttrString(ctypes, "c_int16");
numpyType = NPY_SHORT;
break;
case INT:
numpyType = NPY_INT;
break;
case LONG:
numpyType = NPY_INT64;
break;
case UINT16:
ctype = PyObject_GetAttrString(ctypes, "c_uint16");
numpyType = NPY_USHORT;
break;
case UINT32:
ctype = PyObject_GetAttrString(ctypes, "c_uint32");
numpyType = NPY_UINT;
break;
case UINT64:
ctype = PyObject_GetAttrString(ctypes, "c_uint64");
numpyType = NPY_UINT64;
break;
case BOOL:
ctype = PyObject_GetAttrString(ctypes, "c_bool");
numpyType = NPY_BOOL;
break;
case BYTE:
ctype = PyObject_GetAttrString(ctypes, "c_byte");
numpyType = NPY_BYTE;
break;
case UBYTE:
ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
numpyType = NPY_UBYTE;
break;
case HALF:
numpyType = NPY_HALF;
break;
default:
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
}
PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
PyObject argsTuple = PyTuple_New(1);
PyTuple_SetItem(argsTuple, 0, ctype);
PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
PyObject address = PyLong_FromLong(npArray.getAddress());
PyObject argsTuple2 = PyTuple_New(2);
PyTuple_SetItem(argsTuple2, 0, address);
PyTuple_SetItem(argsTuple2, 1, ptrType);
PyObject ptr = PyObject_Call(cast, argsTuple2, null);
PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
for (int i = 0; i < npArray.getShape().length; i++) {
PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
PyTuple_SetItem(shapeTuple, i, dim);
Py_DecRef(dim);
long[] shape = indArray.shape();
INDArray inputArray = indArray;
if(dataType == DataType.BFLOAT16) {
log.warn("\n\nThe given nd4j array \n\n{}\n\n is of BFLOAT16 datatype. " +
"Casting a copy of it to FLOAT and creating the respective numpy array from it.\n", indArray);
inputArray = indArray.castTo(DataType.FLOAT);
}
PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
PyObject argsTuple3 = PyTuple_New(2);
PyTuple_SetItem(argsTuple3, 0, ptr);
PyTuple_SetItem(argsTuple3, 1, shapeTuple);
nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
Py_DecRef(ctypesPointer);
Py_DecRef(ctypesLib);
Py_DecRef(argsTuple);
Py_DecRef(argsTuple2);
Py_DecRef(argsTuple3);
Py_DecRef(cast);
Py_DecRef(asArray);
//Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
if(inputArray.data() instanceof BaseDataBuffer){
((BaseDataBuffer)inputArray.data()).syncToPrimary();
}
nativePythonObject = PyArray_New(PyArray_Type(), shape.length, new SizeTPointer(shape),
numpyType, null,
inputArray.data().addressPointer(),
0, NPY_ARRAY_CARRAY, null);
}
@ -321,57 +317,60 @@ public class PythonObject {
return toInt() != 0;
}
public NumpyArray toNumpy() {
PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
PyObject data = PyDict_GetItemString(arrInterface, "data");
PyObject pyAddress = PyTuple_GetItem(data, 0);
long address = PyLong_AsLong(pyAddress);
PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
String dtypeName = pyObjectToString(pyDtypeName);
Py_DecRef(pyDtype);
Py_DecRef(pyDtypeName);
PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
int ndim = (int) PyObject_Size(shape);
long[] jshape = new long[ndim];
long[] jstrides = new long[ndim];
for (int i = 0; i < ndim; i++) {
jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
public NumpyArray toNumpy() throws PythonException{
PyObject np = PyImport_ImportModule("numpy");
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
}
Py_DecRef(shape);
Py_DecRef(strides);
Py_DecRef(ndarray);
Py_DecRef(np);
Pointer objPtr = new Pointer(nativePythonObject);
PyArrayObject npArr = new PyArrayObject(objPtr);
Pointer ptr = PyArray_DATA(npArr);
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
long[] shape = new long[PyArray_NDIM(npArr)];
shapePtr.get(shape, 0, shape.length);
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
long[] strides = new long[shape.length];
stridesPtr.get(strides, 0, strides.length);
int npdtype = PyArray_TYPE(npArr);
DataType dtype;
if (dtypeName.equals("float64")) {
dtype = DataType.DOUBLE;
} else if (dtypeName.equals("float32")) {
dtype = DataType.FLOAT;
} else if (dtypeName.equals("int8")){
dtype = DataType.INT8;
}else if (dtypeName.equals("int16")) {
dtype = DataType.SHORT;
} else if (dtypeName.equals("int32")) {
dtype = DataType.INT;
} else if (dtypeName.equals("int64")) {
dtype = DataType.LONG;
switch (npdtype){
case NPY_DOUBLE:
dtype = DataType.DOUBLE; break;
case NPY_FLOAT:
dtype = DataType.FLOAT; break;
case NPY_SHORT:
dtype = DataType.SHORT; break;
case NPY_INT:
dtype = DataType.INT; break;
case NPY_LONG:
dtype = DataType.LONG; break;
case NPY_UINT:
dtype = DataType.UINT32; break;
case NPY_BYTE:
dtype = DataType.BYTE; break;
case NPY_UBYTE:
dtype = DataType.UBYTE; break;
case NPY_BOOL:
dtype = DataType.BOOL; break;
case NPY_HALF:
dtype = DataType.HALF; break;
case NPY_LONGLONG:
dtype = DataType.INT64; break;
case NPY_USHORT:
dtype = DataType.UINT16; break;
case NPY_ULONG:
dtype = DataType.UINT64; break;
case NPY_ULONGLONG:
dtype = DataType.UINT64; break;
default:
throw new PythonException("Unsupported array data type: " + npdtype);
}
else if (dtypeName.equals("uint8")){
dtype = DataType.UINT8;
}
else if (dtypeName.equals("uint16")){
dtype = DataType.UINT16;
}
else if (dtypeName.equals("uint32")){
dtype = DataType.UINT32;
}
else if (dtypeName.equals("uint64")){
dtype = DataType.UINT64;
}
else {
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
}
return new NumpyArray(address, jshape, jstrides, dtype);
return new NumpyArray(ptr.address(), shape, strides, dtype);
}
@ -442,7 +441,6 @@ public class PythonObject {
return get(key.nativePythonObject);
}
public PythonObject get(int key) {
return get(PyLong_FromLong((long) key));
}
@ -450,14 +448,12 @@ public class PythonObject {
public PythonObject get(long key) {
return new PythonObject(
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
);
}
public PythonObject get(double key) {
return new PythonObject(
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
);
}
@ -479,7 +475,6 @@ public class PythonObject {
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
String jsonString = serialized.toString();
return new JSONArray(jsonString);
}
public JSONObject toJSONObject() throws PythonException {
@ -547,13 +542,16 @@ public class PythonObject {
String originalContext = Python.getCurrentContext();
Python.setContext(tempContext);
PythonExecutioner.setVariable("memview", this);
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
PythonExecutioner.exec("import numpy as np\narr = np.frombuffer(memview, dtype='int8')");
INDArray arr = PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray();
if(arr.data() instanceof BaseDataBuffer){
((BaseDataBuffer)arr.data()).syncToPrimary();
}
BytePointer ret = new BytePointer(arr.data().pointer());
Python.setContext(originalContext);
Python.deleteContext(tempContext);
return ret;
}
else{
} else {
PyObject ctypes = PyImport_ImportModule("ctypes");
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
@ -579,13 +577,10 @@ public class PythonObject {
else{
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
}
}
}
public boolean isNone() {
return (nativePythonObject == null)||
(toString().equals("None") && Python.type(this).toString().equals("<class 'NoneType'>"));
}
}

View File

@ -314,7 +314,7 @@ public class PythonVariables implements java.io.Serializable {
* @param name the field to add
* @param value the value to add
*/
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
public void addNDArray(String name, INDArray value) {
vars.put(name, PythonType.TypeName.NDARRAY);
ndVars.put(name, value);
}

View File

@ -0,0 +1,76 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static junit.framework.TestCase.assertEquals;
@RunWith(Parameterized.class)
public class PythonNumpyTest {
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] data() {
return new DataType[] {
DataType.BOOL,
DataType.FLOAT16,
DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
}
private DataType dataType;
public PythonNumpyTest(DataType dataType) {
this.dataType = dataType;
}
@Test
public void numpyAndNd4jConversions() throws Exception {
INDArray input = Nd4j.ones(dataType, 2, 2, 2);
PythonVariables inputs = new PythonVariables();
inputs.addNDArray("x", input);
PythonVariables outputs = new PythonVariables();
outputs.addNDArray("y");
PythonJob pythonJob = new PythonJob(String.format("job_%s", dataType.name()) + dataType.name(), "y = x", false);
pythonJob.exec(inputs, outputs);
INDArray output = outputs.getNDArrayValue("y");
// As numpy doesn't support BFLOAT16 we'll convert it to FLOAT
assertEquals(dataType == DataType.BFLOAT16 ? input.castTo(DataType.FLOAT) : input,
output);
}
}

View File

@ -20,10 +20,13 @@ import org.bytedeco.javacpp.BytePointer;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import java.lang.reflect.Method;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@ -223,6 +226,35 @@ public class TestPythonExecutioner {
}
@Test
public void testNDArrayNoCopy() throws Exception{
PythonVariables pyInputs = new PythonVariables();
PythonVariables pyOutputs = new PythonVariables();
INDArray arr = Nd4j.rand(3, 2);
((BaseDataBuffer)arr.data()).syncToPrimary();
pyInputs.addNDArray("x", arr);
pyOutputs.addNDArray("x");
INDArray expected = arr.mul(2.3);
String code = "x *= 2.3";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals(pyInputs.getNDArrayValue("x"), pyOutputs.getNDArrayValue("x"));
Assert.assertEquals(expected, pyOutputs.getNDArrayValue("x"));
Assert.assertEquals(arr.data().address(), pyOutputs.getNDArrayValue("x").data().address());
}
@Test
public void testNDArrayInplace() throws Exception{
PythonVariables pyInputs = new PythonVariables();
INDArray arr = Nd4j.rand(3, 2);
((BaseDataBuffer)arr.data()).syncToPrimary();
pyInputs.addNDArray("x", arr);
INDArray expected = arr.mul(2.3);
String code = "x *= 2.3";
Python.exec(code, pyInputs, null);
Assert.assertEquals(expected, arr);
}
@Test
public void testByteBufferInput() throws Exception{
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
@ -230,8 +262,7 @@ public class TestPythonExecutioner {
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
((BaseDataBuffer)buff.data()).syncToPrimary();
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
@ -251,6 +282,7 @@ public class TestPythonExecutioner {
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
((BaseDataBuffer)buff.data()).syncToPrimary();
PythonVariables pyInputs = new PythonVariables();
@ -262,6 +294,7 @@ public class TestPythonExecutioner {
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
Assert.assertEquals(buff.data().address(), pyOutputs.getBytesValue("buff").address());
}
@Test
@ -270,6 +303,8 @@ public class TestPythonExecutioner {
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
((BaseDataBuffer)buff.data()).syncToPrimary();
PythonVariables pyInputs = new PythonVariables();
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
String code = "buff[0]+=2\nbuff[2]-=2";
@ -288,6 +323,7 @@ public class TestPythonExecutioner {
buff.putScalar(0, 97); // a
buff.putScalar(1, 98); // b
buff.putScalar(2, 99); // c
((BaseDataBuffer)buff.data()).syncToPrimary();
PythonVariables pyInputs = new PythonVariables();
@ -300,6 +336,30 @@ public class TestPythonExecutioner {
Python.exec(code, pyInputs, pyOutputs);
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
}
@Test
public void testDoubleDeviceAllocation() throws Exception{
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
return;
}
// Test to make sure that multiple device buffers are not allocated
// for the same host buffer
INDArray arr = Nd4j.rand(3, 2);
((BaseDataBuffer)arr.data()).syncToPrimary();
long deviceAddress1 = getDeviceAddress(arr);
PythonVariables pyInputs = new PythonVariables();
pyInputs.addNDArray("arr", arr);
PythonVariables pyOutputs = new PythonVariables();
pyOutputs.addNDArray("arr");
String code = "arr += 2";
Python.exec(code, pyInputs, pyOutputs);
INDArray arr2 = pyOutputs.getNDArrayValue("arr");
long deviceAddress2 = getDeviceAddress(arr2);
Assert.assertEquals(deviceAddress1, deviceAddress2);
}
@Test
public void testBadCode() throws Exception{
Python.setContext("badcode");
@ -333,5 +393,22 @@ public class TestPythonExecutioner {
Assert.assertEquals("y", notNone.toString());
}
private static long getDeviceAddress(INDArray array){
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
throw new IllegalStateException("Cannot ge device pointer for non-CUDA device");
}
//Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
// due to it being defined in nd4j-native-api, not nd4j-api
try {
Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
Method m = c.getMethod("getOpaqueDataBuffer");
OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
long address = db.specialBuffer().address();
return address;
} catch (Throwable t){
throw new RuntimeException("Error getting OpaqueDataBuffer", t);
}
}
}

View File

@ -24,6 +24,7 @@ package org.datavec.python;
import org.bytedeco.javacpp.BytePointer;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -55,6 +56,7 @@ public class TestPythonVariables {
};
INDArray arr = Nd4j.scalar(1.0);
((BaseDataBuffer)arr.data()).syncToPrimary();
BytePointer bp = new BytePointer(arr.data().pointer());
Object[] values = {
1L,1.0,"1",true, Collections.singletonMap("1",1),

View File

@ -1945,4 +1945,17 @@ public abstract class BaseDataBuffer implements DataBuffer {
public boolean wasClosed() {
return released;
}
/**
* This method synchronizes host memory
*/
public abstract void syncToPrimary();
/**
* This method synchronizes device memory
*/
public abstract void syncToSpecial();
}

View File

@ -230,13 +230,14 @@ public enum DataType {
public static DataType fromNumpy(String numpyDtypeName){
switch (numpyDtypeName.toLowerCase()){
case "bool": return BOOL;
case "byte": return BYTE;
case "int8": return BYTE;
case "int16": return SHORT;
case "int32": return INT;
case "int64": return LONG;
case "uint8": return UBYTE;
case "float16": return HALF;
case "byte":
case "int8":
return INT8;
case "int16": return INT16;
case "int32": return INT32;
case "int64": return INT64;
case "uint8": return UINT8;
case "float16": return FLOAT16;
case "float32": return FLOAT;
case "float64": return DOUBLE;
case "uint16": return UINT16;

View File

@ -211,6 +211,16 @@ public class CompressedDataBuffer extends BaseDataBuffer {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
@Override
public void syncToPrimary() {
//No-op
}
@Override
public void syncToSpecial() {
//No-op
}
@Override
protected double getDoubleUnsynced(long index) {
return super.getDouble(index);

View File

@ -51,7 +51,6 @@ public class OpaqueDataBuffer extends Pointer {
try {
// try to allocate data buffer
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {

View File

@ -1845,4 +1845,14 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
public int targetDevice() {
return AtomicAllocator.getInstance().getAllocationPoint(this).getDeviceId();
}
@Override
public void syncToPrimary(){
ptrDataBuffer.syncToPrimary();
}
@Override
public void syncToSpecial(){
ptrDataBuffer.syncToSpecial();
}
}

View File

@ -951,4 +951,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
return this;
}
@Override
public void syncToPrimary(){
ptrDataBuffer.syncToPrimary();
}
@Override
public void syncToSpecial(){
ptrDataBuffer.syncToSpecial();
}
}

View File

@ -32,6 +32,8 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.nativeblas.NativeOpsHolder;
import sun.nio.ch.DirectBuffer;
import java.nio.ByteBuffer;
@ -398,6 +400,29 @@ public class DataBufferTests extends BaseNd4jTest {
}
}
@Test
public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1);
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
System.out.println(bb.getClass());
System.out.println(bb.address());
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
long before = arr2.data().pointer().address();
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
long after = arr2.data().pointer().address();
assertEquals(before, after);
}
@Override
public char ordering() {
return 'c';