diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisMergeFunction.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisMergeFunction.java deleted file mode 100644 index b57bf97cd..000000000 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisMergeFunction.java +++ /dev/null @@ -1,29 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.transform.analysis.counter; - -import java.util.function.BiFunction; - -/** - * Created by Alex on 5/03/2016. - */ -public class StringAnalysisMergeFunction - implements BiFunction { - public StringAnalysisCounter apply(StringAnalysisCounter v1, StringAnalysisCounter v2) { - return v1.merge(v2); - } -} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index dd1613d0a..708184de7 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -19,6 +19,7 @@ package org.datavec.python; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import org.apache.commons.lang3.ArrayUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.concurrency.AffinityManager; @@ -29,6 +30,10 @@ import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.linalg.api.buffer.DataType; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + import static org.nd4j.linalg.api.buffer.DataType.FLOAT; @@ -42,6 +47,7 @@ import static org.nd4j.linalg.api.buffer.DataType.FLOAT; public class NumpyArray { private static NativeOps nativeOps; + private static Map arrayCache; // Avoids re-allocation of device buffer private long address; private long[] shape; private long[] strides; @@ -52,6 +58,7 @@ public class NumpyArray { //initialize Nd4j.scalar(1.0); nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + arrayCache = new HashMap<>(); } @Builder @@ -84,24 +91,42 @@ public class NumpyArray { private void setND4JArray() { + long size = 1; for (long d : shape) { size *= d; } - Pointer ptr = nativeOps.pointerForAddress(address); - ptr = ptr.limit(size); - ptr = ptr.capacity(size); - DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype); - int elemSize = buff.getElementSize(); - long[] nd4jStrides = new long[strides.length]; - for (int i = 0; i < strides.length; i++) { - nd4jStrides[i] = strides[i] / elemSize; - } - nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); + String cacheKey = address + "_" + size + "_" + dtype + "_" + ArrayUtils.toString(strides); + nd4jArray = arrayCache.get(cacheKey); + if (nd4jArray == null) { + Pointer ptr = nativeOps.pointerForAddress(address); + ptr = ptr.limit(size); + ptr = ptr.capacity(size); + DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype); + + int elemSize = buff.getElementSize(); + long[] nd4jStrides = new long[strides.length]; + for (int i = 0; i < strides.length; i++) { + nd4jStrides[i] = strides[i] / elemSize; + } + + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); + arrayCache.put(cacheKey, nd4jArray); + } + else{ + if (!Arrays.equals(nd4jArray.shape(), shape)){ + nd4jArray = nd4jArray.reshape(shape); + } + } Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } + public INDArray getNd4jArray(){ + Nd4j.getAffinityManager().tagLocation(nd4jArray, AffinityManager.Location.HOST); + return nd4jArray; + } + public NumpyArray(INDArray nd4jArray) { Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); @@ -115,6 +140,8 @@ public class NumpyArray { } dtype = nd4jArray.dataType(); this.nd4jArray = nd4jArray; + String cacheKey = address + "_" + nd4jArray.length() + "_" + dtype + "_" + ArrayUtils.toString(strides); + arrayCache.put(cacheKey, nd4jArray); } } \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java index 9dabbef2d..98c9b964c 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java @@ -21,6 +21,7 @@ package org.datavec.python; import org.bytedeco.cpython.PyObject; import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.numpy.global.numpy.PyArray_EnsureArray; /** * Swift like python wrapper for Java @@ -232,6 +233,10 @@ public class Python { } + public static PythonObject ndarray(PythonObject pythonObject){ + return new PythonObject(PyArray_EnsureArray(pythonObject.getNativePythonObject())); + } + public static boolean callable(PythonObject pythonObject) { return PyCallable_Check(pythonObject.getNativePythonObject()) == 1; } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 530dd0e02..3d08d3141 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -21,6 +21,9 @@ package org.datavec.python; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java index 84dd16e73..0408e3a59 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -18,11 +18,15 @@ package org.datavec.python; +import lombok.extern.slf4j.Slf4j; import org.bytedeco.cpython.PyObject; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.SizeTPointer; +import org.bytedeco.numpy.PyArrayObject; import org.json.JSONArray; import org.json.JSONObject; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.nativeblas.NativeOpsHolder; @@ -30,7 +34,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.*; import static org.bytedeco.cpython.global.python.*; -import static org.bytedeco.cpython.global.python.PyObject_SetItem; +import static org.bytedeco.numpy.global.numpy.*; /** * Swift like python wrapper for J @@ -38,6 +42,7 @@ import static org.bytedeco.cpython.global.python.PyObject_SetItem; * @author Fariz Rahman */ +@Slf4j public class PythonObject { private PyObject nativePythonObject; @@ -77,78 +82,69 @@ public class PythonObject { } public PythonObject(NumpyArray npArray) { - PyObject ctypes = PyImport_ImportModule("ctypes"); - PyObject np = PyImport_ImportModule("numpy"); - PyObject ctype; - switch (npArray.getDtype()) { + int numpyType; + INDArray indArray = npArray.getNd4jArray(); + DataType dataType = indArray.dataType(); + + switch (dataType) { case DOUBLE: - ctype = PyObject_GetAttrString(ctypes, "c_double"); + numpyType = NPY_DOUBLE; break; case FLOAT: - ctype = PyObject_GetAttrString(ctypes, "c_float"); - break; - case LONG: - ctype = PyObject_GetAttrString(ctypes, "c_int64"); - break; - case INT: - ctype = PyObject_GetAttrString(ctypes, "c_int32"); + case BFLOAT16: + numpyType = NPY_FLOAT; break; case SHORT: - ctype = PyObject_GetAttrString(ctypes, "c_int16"); + numpyType = NPY_SHORT; + break; + case INT: + numpyType = NPY_INT; + break; + case LONG: + numpyType = NPY_INT64; break; case UINT16: - ctype = PyObject_GetAttrString(ctypes, "c_uint16"); + numpyType = NPY_USHORT; break; case UINT32: - ctype = PyObject_GetAttrString(ctypes, "c_uint32"); + numpyType = NPY_UINT; break; case UINT64: - ctype = PyObject_GetAttrString(ctypes, "c_uint64"); + numpyType = NPY_UINT64; break; case BOOL: - ctype = PyObject_GetAttrString(ctypes, "c_bool"); + numpyType = NPY_BOOL; break; case BYTE: - ctype = PyObject_GetAttrString(ctypes, "c_byte"); + numpyType = NPY_BYTE; break; case UBYTE: - ctype = PyObject_GetAttrString(ctypes, "c_ubyte"); + numpyType = NPY_UBYTE; + break; + case HALF: + numpyType = NPY_HALF; break; default: throw new RuntimeException("Unsupported dtype: " + npArray.getDtype()); } - PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER"); - PyObject argsTuple = PyTuple_New(1); - PyTuple_SetItem(argsTuple, 0, ctype); - PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null); - - PyObject cast = PyObject_GetAttrString(ctypes, "cast"); - PyObject address = PyLong_FromLong(npArray.getAddress()); - PyObject argsTuple2 = PyTuple_New(2); - PyTuple_SetItem(argsTuple2, 0, address); - PyTuple_SetItem(argsTuple2, 1, ptrType); - PyObject ptr = PyObject_Call(cast, argsTuple2, null); - PyObject shapeTuple = PyTuple_New(npArray.getShape().length); - for (int i = 0; i < npArray.getShape().length; i++) { - PyObject dim = PyLong_FromLong(npArray.getShape()[i]); - PyTuple_SetItem(shapeTuple, i, dim); - Py_DecRef(dim); + long[] shape = indArray.shape(); + INDArray inputArray = indArray; + if(dataType == DataType.BFLOAT16) { + log.warn("\n\nThe given nd4j array \n\n{}\n\n is of BFLOAT16 datatype. " + + "Casting a copy of it to FLOAT and creating the respective numpy array from it.\n", indArray); + inputArray = indArray.castTo(DataType.FLOAT); } - PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib"); - PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array"); - PyObject argsTuple3 = PyTuple_New(2); - PyTuple_SetItem(argsTuple3, 0, ptr); - PyTuple_SetItem(argsTuple3, 1, shapeTuple); - nativePythonObject = PyObject_Call(asArray, argsTuple3, null); - Py_DecRef(ctypesPointer); - Py_DecRef(ctypesLib); - Py_DecRef(argsTuple); - Py_DecRef(argsTuple2); - Py_DecRef(argsTuple3); - Py_DecRef(cast); - Py_DecRef(asArray); + //Sync to host memory in the case of CUDA, before passing the host memory pointer to Python + if(inputArray.data() instanceof BaseDataBuffer){ + ((BaseDataBuffer)inputArray.data()).syncToPrimary(); + } + + nativePythonObject = PyArray_New(PyArray_Type(), shape.length, new SizeTPointer(shape), + numpyType, null, + inputArray.data().addressPointer(), + 0, NPY_ARRAY_CARRAY, null); } @@ -321,57 +317,60 @@ public class PythonObject { return toInt() != 0; } - public NumpyArray toNumpy() { - PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() ! - PyObject data = PyDict_GetItemString(arrInterface, "data"); - PyObject pyAddress = PyTuple_GetItem(data, 0); - long address = PyLong_AsLong(pyAddress); - PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype"); - PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name"); - String dtypeName = pyObjectToString(pyDtypeName); - Py_DecRef(pyDtype); - Py_DecRef(pyDtypeName); - PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape"); - PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides"); - int ndim = (int) PyObject_Size(shape); - long[] jshape = new long[ndim]; - long[] jstrides = new long[ndim]; - for (int i = 0; i < ndim; i++) { - jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i)); - jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i)); + public NumpyArray toNumpy() throws PythonException{ + PyObject np = PyImport_ImportModule("numpy"); + PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); + if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){ + throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array."); } - Py_DecRef(shape); - Py_DecRef(strides); + Py_DecRef(ndarray); + Py_DecRef(np); + Pointer objPtr = new Pointer(nativePythonObject); + PyArrayObject npArr = new PyArrayObject(objPtr); + Pointer ptr = PyArray_DATA(npArr); + SizeTPointer shapePtr = PyArray_SHAPE(npArr); + long[] shape = new long[PyArray_NDIM(npArr)]; + shapePtr.get(shape, 0, shape.length); + SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + long[] strides = new long[shape.length]; + stridesPtr.get(strides, 0, strides.length); + int npdtype = PyArray_TYPE(npArr); + DataType dtype; - if (dtypeName.equals("float64")) { - dtype = DataType.DOUBLE; - } else if (dtypeName.equals("float32")) { - dtype = DataType.FLOAT; - } else if (dtypeName.equals("int8")){ - dtype = DataType.INT8; - }else if (dtypeName.equals("int16")) { - dtype = DataType.SHORT; - } else if (dtypeName.equals("int32")) { - dtype = DataType.INT; - } else if (dtypeName.equals("int64")) { - dtype = DataType.LONG; + switch (npdtype){ + case NPY_DOUBLE: + dtype = DataType.DOUBLE; break; + case NPY_FLOAT: + dtype = DataType.FLOAT; break; + case NPY_SHORT: + dtype = DataType.SHORT; break; + case NPY_INT: + dtype = DataType.INT; break; + case NPY_LONG: + dtype = DataType.LONG; break; + case NPY_UINT: + dtype = DataType.UINT32; break; + case NPY_BYTE: + dtype = DataType.BYTE; break; + case NPY_UBYTE: + dtype = DataType.UBYTE; break; + case NPY_BOOL: + dtype = DataType.BOOL; break; + case NPY_HALF: + dtype = DataType.HALF; break; + case NPY_LONGLONG: + dtype = DataType.INT64; break; + case NPY_USHORT: + dtype = DataType.UINT16; break; + case NPY_ULONG: + dtype = DataType.UINT64; break; + case NPY_ULONGLONG: + dtype = DataType.UINT64; break; + default: + throw new PythonException("Unsupported array data type: " + npdtype); } - else if (dtypeName.equals("uint8")){ - dtype = DataType.UINT8; - } - else if (dtypeName.equals("uint16")){ - dtype = DataType.UINT16; - } - else if (dtypeName.equals("uint32")){ - dtype = DataType.UINT32; - } - else if (dtypeName.equals("uint64")){ - dtype = DataType.UINT64; - } - else { - throw new RuntimeException("Unsupported array type " + dtypeName + "."); - } - return new NumpyArray(address, jshape, jstrides, dtype); + + return new NumpyArray(ptr.address(), shape, strides, dtype); } @@ -442,7 +441,6 @@ public class PythonObject { return get(key.nativePythonObject); } - public PythonObject get(int key) { return get(PyLong_FromLong((long) key)); } @@ -450,14 +448,12 @@ public class PythonObject { public PythonObject get(long key) { return new PythonObject( PyObject_GetItem(nativePythonObject, PyLong_FromLong(key)) - ); } public PythonObject get(double key) { return new PythonObject( PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key)) - ); } @@ -479,7 +475,6 @@ public class PythonObject { PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer()); String jsonString = serialized.toString(); return new JSONArray(jsonString); - } public JSONObject toJSONObject() throws PythonException { @@ -547,13 +542,16 @@ public class PythonObject { String originalContext = Python.getCurrentContext(); Python.setContext(tempContext); PythonExecutioner.setVariable("memview", this); - PythonExecutioner.exec("import numpy as np\narr = np.array(memview)"); - BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer()); + PythonExecutioner.exec("import numpy as np\narr = np.frombuffer(memview, dtype='int8')"); + INDArray arr = PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray(); + if(arr.data() instanceof BaseDataBuffer){ + ((BaseDataBuffer)arr.data()).syncToPrimary(); + } + BytePointer ret = new BytePointer(arr.data().pointer()); Python.setContext(originalContext); Python.deleteContext(tempContext); return ret; - } - else{ + } else { PyObject ctypes = PyImport_ImportModule("ctypes"); PyObject cArrType = PyObject_GetAttrString(ctypes, "Array"); if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){ @@ -579,13 +577,10 @@ public class PythonObject { else{ throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString()); } - } } public boolean isNone() { return (nativePythonObject == null)|| (toString().equals("None") && Python.type(this).toString().equals("")); - } - } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java index 9d8b5c2a1..ade9bdfa0 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java @@ -314,7 +314,7 @@ public class PythonVariables implements java.io.Serializable { * @param name the field to add * @param value the value to add */ - public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { + public void addNDArray(String name, INDArray value) { vars.put(name, PythonType.TypeName.NDARRAY); ndVars.put(name, value); } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java b/datavec/datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java new file mode 100644 index 000000000..89ea83552 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java @@ -0,0 +1,76 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; + +@RunWith(Parameterized.class) +public class PythonNumpyTest { + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] data() { + return new DataType[] { + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + + private DataType dataType; + + public PythonNumpyTest(DataType dataType) { + this.dataType = dataType; + } + + @Test + public void numpyAndNd4jConversions() throws Exception { + INDArray input = Nd4j.ones(dataType, 2, 2, 2); + + PythonVariables inputs = new PythonVariables(); + inputs.addNDArray("x", input); + + PythonVariables outputs = new PythonVariables(); + outputs.addNDArray("y"); + + PythonJob pythonJob = new PythonJob(String.format("job_%s", dataType.name()) + dataType.name(), "y = x", false); + + pythonJob.exec(inputs, outputs); + + INDArray output = outputs.getNDArrayValue("y"); + + // As numpy doesn't support BFLOAT16 we'll convert it to FLOAT + assertEquals(dataType == DataType.BFLOAT16 ? input.castTo(DataType.FLOAT) : input, + output); + } +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index 52e2aad56..027a534bc 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -20,10 +20,13 @@ import org.bytedeco.javacpp.BytePointer; import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.OpaqueDataBuffer; +import java.lang.reflect.Method; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -223,6 +226,35 @@ public class TestPythonExecutioner { } + + @Test + public void testNDArrayNoCopy() throws Exception{ + PythonVariables pyInputs = new PythonVariables(); + PythonVariables pyOutputs = new PythonVariables(); + INDArray arr = Nd4j.rand(3, 2); + ((BaseDataBuffer)arr.data()).syncToPrimary(); + pyInputs.addNDArray("x", arr); + pyOutputs.addNDArray("x"); + INDArray expected = arr.mul(2.3); + String code = "x *= 2.3"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals(pyInputs.getNDArrayValue("x"), pyOutputs.getNDArrayValue("x")); + Assert.assertEquals(expected, pyOutputs.getNDArrayValue("x")); + Assert.assertEquals(arr.data().address(), pyOutputs.getNDArrayValue("x").data().address()); + } + + @Test + public void testNDArrayInplace() throws Exception{ + PythonVariables pyInputs = new PythonVariables(); + INDArray arr = Nd4j.rand(3, 2); + ((BaseDataBuffer)arr.data()).syncToPrimary(); + pyInputs.addNDArray("x", arr); + INDArray expected = arr.mul(2.3); + String code = "x *= 2.3"; + Python.exec(code, pyInputs, null); + Assert.assertEquals(expected, arr); + } + @Test public void testByteBufferInput() throws Exception{ //ByteBuffer buff = ByteBuffer.allocateDirect(3); @@ -230,8 +262,7 @@ public class TestPythonExecutioner { buff.putScalar(0, 97); // a buff.putScalar(1, 98); // b buff.putScalar(2, 99); // c - - + ((BaseDataBuffer)buff.data()).syncToPrimary(); PythonVariables pyInputs = new PythonVariables(); pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); @@ -251,6 +282,7 @@ public class TestPythonExecutioner { buff.putScalar(0, 97); // a buff.putScalar(1, 98); // b buff.putScalar(2, 99); // c + ((BaseDataBuffer)buff.data()).syncToPrimary(); PythonVariables pyInputs = new PythonVariables(); @@ -262,6 +294,7 @@ public class TestPythonExecutioner { String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97"; Python.exec(code, pyInputs, pyOutputs); Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString()); + Assert.assertEquals(buff.data().address(), pyOutputs.getBytesValue("buff").address()); } @Test @@ -270,6 +303,8 @@ public class TestPythonExecutioner { buff.putScalar(0, 97); // a buff.putScalar(1, 98); // b buff.putScalar(2, 99); // c + ((BaseDataBuffer)buff.data()).syncToPrimary(); + PythonVariables pyInputs = new PythonVariables(); pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); String code = "buff[0]+=2\nbuff[2]-=2"; @@ -288,6 +323,7 @@ public class TestPythonExecutioner { buff.putScalar(0, 97); // a buff.putScalar(1, 98); // b buff.putScalar(2, 99); // c + ((BaseDataBuffer)buff.data()).syncToPrimary(); PythonVariables pyInputs = new PythonVariables(); @@ -300,6 +336,30 @@ public class TestPythonExecutioner { Python.exec(code, pyInputs, pyOutputs); Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString()); } + + @Test + public void testDoubleDeviceAllocation() throws Exception{ + if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + return; + } + // Test to make sure that multiple device buffers are not allocated + // for the same host buffer + INDArray arr = Nd4j.rand(3, 2); + ((BaseDataBuffer)arr.data()).syncToPrimary(); + long deviceAddress1 = getDeviceAddress(arr); + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addNDArray("arr", arr); + PythonVariables pyOutputs = new PythonVariables(); + pyOutputs.addNDArray("arr"); + String code = "arr += 2"; + Python.exec(code, pyInputs, pyOutputs); + INDArray arr2 = pyOutputs.getNDArrayValue("arr"); + long deviceAddress2 = getDeviceAddress(arr2); + Assert.assertEquals(deviceAddress1, deviceAddress2); + + + } + @Test public void testBadCode() throws Exception{ Python.setContext("badcode"); @@ -333,5 +393,22 @@ public class TestPythonExecutioner { Assert.assertEquals("y", notNone.toString()); } + private static long getDeviceAddress(INDArray array){ + if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + throw new IllegalStateException("Cannot ge device pointer for non-CUDA device"); + } + + //Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer + // due to it being defined in nd4j-native-api, not nd4j-api + try { + Class c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer"); + Method m = c.getMethod("getOpaqueDataBuffer"); + OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data()); + long address = db.specialBuffer().address(); + return address; + } catch (Throwable t){ + throw new RuntimeException("Error getting OpaqueDataBuffer", t); + } + } } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java index 22f8ba230..20adb720f 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java @@ -24,6 +24,7 @@ package org.datavec.python; import org.bytedeco.javacpp.BytePointer; import org.junit.Test; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -55,6 +56,7 @@ public class TestPythonVariables { }; INDArray arr = Nd4j.scalar(1.0); + ((BaseDataBuffer)arr.data()).syncToPrimary(); BytePointer bp = new BytePointer(arr.data().pointer()); Object[] values = { 1L,1.0,"1",true, Collections.singletonMap("1",1), diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 12e27e1c2..6b226ce20 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1945,4 +1945,17 @@ public abstract class BaseDataBuffer implements DataBuffer { public boolean wasClosed() { return released; } + + + /** + * This method synchronizes host memory + */ + public abstract void syncToPrimary(); + + /** + * This method synchronizes device memory + */ + public abstract void syncToSpecial(); + + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java index 94cfdca43..c48b7577c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java @@ -230,13 +230,14 @@ public enum DataType { public static DataType fromNumpy(String numpyDtypeName){ switch (numpyDtypeName.toLowerCase()){ case "bool": return BOOL; - case "byte": return BYTE; - case "int8": return BYTE; - case "int16": return SHORT; - case "int32": return INT; - case "int64": return LONG; - case "uint8": return UBYTE; - case "float16": return HALF; + case "byte": + case "int8": + return INT8; + case "int16": return INT16; + case "int32": return INT32; + case "int64": return INT64; + case "uint8": return UINT8; + case "float16": return FLOAT16; case "float32": return FLOAT; case "float64": return DOUBLE; case "uint16": return UINT16; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index f1c9ed6d9..7cbfb0d70 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -211,6 +211,16 @@ public class CompressedDataBuffer extends BaseDataBuffer { throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); } + @Override + public void syncToPrimary() { + //No-op + } + + @Override + public void syncToSpecial() { + //No-op + } + @Override protected double getDoubleUnsynced(long index) { return super.getDouble(index); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java index 4d5edea0c..d7c2e0ac0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueDataBuffer.java @@ -51,7 +51,6 @@ public class OpaqueDataBuffer extends Pointer { try { // try to allocate data buffer buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth); - // check error code ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); if (ec != 0) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index f944d20cf..f84f96384 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1845,4 +1845,14 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda public int targetDevice() { return AtomicAllocator.getInstance().getAllocationPoint(this).getDeviceId(); } + + @Override + public void syncToPrimary(){ + ptrDataBuffer.syncToPrimary(); + } + + @Override + public void syncToSpecial(){ + ptrDataBuffer.syncToSpecial(); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index a5ddc7aef..b3def0f71 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -951,4 +951,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo return this; } + @Override + public void syncToPrimary(){ + ptrDataBuffer.syncToPrimary(); + } + + @Override + public void syncToSpecial(){ + ptrDataBuffer.syncToSpecial(); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index 00165545b..b271c7bff 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -32,6 +32,8 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.nativeblas.NativeOpsHolder; +import sun.nio.ch.DirectBuffer; import java.nio.ByteBuffer; @@ -398,6 +400,29 @@ public class DataBufferTests extends BaseNd4jTest { } } + + @Test + public void testEnsureLocation(){ + //https://github.com/eclipse/deeplearning4j/issues/8783 + Nd4j.create(1); + + DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5); + System.out.println(bb.getClass()); + System.out.println(bb.address()); + + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address()); + DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE); + + + INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE); + long before = arr2.data().pointer().address(); + Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST); + long after = arr2.data().pointer().address(); + + assertEquals(before, after); + } + + @Override public char ordering() { return 'c';