Support for more numpy datatypes (#241)
* Adding more datatypes support in datavec-python * Using numpy C API for creating numpy arrays * Adding parameterized tests * Adding support for BFLOAT16 (by converting it to FLOAT) * Cleanup * Using casting instead of creating an array * Giving out a warning while casting array from BFLOAT16 to FLOAT * Add syncToPrimary and syncToSpecial methods to BaseDataBuffer Signed-off-by: Alex Black <blacka101@gmail.com> * Python exec: sync to host before passing pointers Signed-off-by: Alex Black <blacka101@gmail.com> * Added copyright header * use np api (#267) * python exec / numpy - check object type before cast (#268) * use np api * verify object before cast * fix cong * cuda fix * inplace test + tiny fix * more test * fix double alloc * rem tags * fix cuda check * Fix implicit CUDA dependency in datavec-python tests; remove new method, add test Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com> Co-authored-by: Fariz Rahman <farizrahman4u@gmail.com>master
parent
5cd143611e
commit
9c77bfa85f
|
@ -1,29 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.api.transform.analysis.counter;
|
||||
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
/**
|
||||
* Created by Alex on 5/03/2016.
|
||||
*/
|
||||
public class StringAnalysisMergeFunction
|
||||
implements BiFunction<StringAnalysisCounter, StringAnalysisCounter, StringAnalysisCounter> {
|
||||
public StringAnalysisCounter apply(StringAnalysisCounter v1, StringAnalysisCounter v2) {
|
||||
return v1.merge(v2);
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ package org.datavec.python;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
|
@ -29,6 +30,10 @@ import org.nd4j.nativeblas.NativeOps;
|
|||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||
|
||||
|
||||
|
@ -42,6 +47,7 @@ import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
|||
public class NumpyArray {
|
||||
|
||||
private static NativeOps nativeOps;
|
||||
private static Map<String, INDArray> arrayCache; // Avoids re-allocation of device buffer
|
||||
private long address;
|
||||
private long[] shape;
|
||||
private long[] strides;
|
||||
|
@ -52,6 +58,7 @@ public class NumpyArray {
|
|||
//initialize
|
||||
Nd4j.scalar(1.0);
|
||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
arrayCache = new HashMap<>();
|
||||
}
|
||||
|
||||
@Builder
|
||||
|
@ -84,24 +91,42 @@ public class NumpyArray {
|
|||
|
||||
|
||||
private void setND4JArray() {
|
||||
|
||||
long size = 1;
|
||||
for (long d : shape) {
|
||||
size *= d;
|
||||
}
|
||||
Pointer ptr = nativeOps.pointerForAddress(address);
|
||||
ptr = ptr.limit(size);
|
||||
ptr = ptr.capacity(size);
|
||||
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
|
||||
int elemSize = buff.getElementSize();
|
||||
long[] nd4jStrides = new long[strides.length];
|
||||
for (int i = 0; i < strides.length; i++) {
|
||||
nd4jStrides[i] = strides[i] / elemSize;
|
||||
}
|
||||
|
||||
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
|
||||
String cacheKey = address + "_" + size + "_" + dtype + "_" + ArrayUtils.toString(strides);
|
||||
nd4jArray = arrayCache.get(cacheKey);
|
||||
if (nd4jArray == null) {
|
||||
Pointer ptr = nativeOps.pointerForAddress(address);
|
||||
ptr = ptr.limit(size);
|
||||
ptr = ptr.capacity(size);
|
||||
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
|
||||
|
||||
int elemSize = buff.getElementSize();
|
||||
long[] nd4jStrides = new long[strides.length];
|
||||
for (int i = 0; i < strides.length; i++) {
|
||||
nd4jStrides[i] = strides[i] / elemSize;
|
||||
}
|
||||
|
||||
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
|
||||
arrayCache.put(cacheKey, nd4jArray);
|
||||
}
|
||||
else{
|
||||
if (!Arrays.equals(nd4jArray.shape(), shape)){
|
||||
nd4jArray = nd4jArray.reshape(shape);
|
||||
}
|
||||
}
|
||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||
}
|
||||
|
||||
public INDArray getNd4jArray(){
|
||||
Nd4j.getAffinityManager().tagLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||
return nd4jArray;
|
||||
}
|
||||
|
||||
public NumpyArray(INDArray nd4jArray) {
|
||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||
DataBuffer buff = nd4jArray.data();
|
||||
|
@ -115,6 +140,8 @@ public class NumpyArray {
|
|||
}
|
||||
dtype = nd4jArray.dataType();
|
||||
this.nd4jArray = nd4jArray;
|
||||
String cacheKey = address + "_" + nd4jArray.length() + "_" + dtype + "_" + ArrayUtils.toString(strides);
|
||||
arrayCache.put(cacheKey, nd4jArray);
|
||||
}
|
||||
|
||||
}
|
|
@ -21,6 +21,7 @@ package org.datavec.python;
|
|||
import org.bytedeco.cpython.PyObject;
|
||||
|
||||
import static org.bytedeco.cpython.global.python.*;
|
||||
import static org.bytedeco.numpy.global.numpy.PyArray_EnsureArray;
|
||||
|
||||
/**
|
||||
* Swift like python wrapper for Java
|
||||
|
@ -232,6 +233,10 @@ public class Python {
|
|||
|
||||
}
|
||||
|
||||
public static PythonObject ndarray(PythonObject pythonObject){
|
||||
return new PythonObject(PyArray_EnsureArray(pythonObject.getNativePythonObject()));
|
||||
}
|
||||
|
||||
public static boolean callable(PythonObject pythonObject) {
|
||||
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
|
||||
}
|
||||
|
|
|
@ -21,6 +21,9 @@ package org.datavec.python;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.bytedeco.numpy.global.numpy;
|
||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
|
|
|
@ -18,11 +18,15 @@
|
|||
package org.datavec.python;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.cpython.PyObject;
|
||||
import org.bytedeco.javacpp.BytePointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.bytedeco.javacpp.SizeTPointer;
|
||||
import org.bytedeco.numpy.PyArrayObject;
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONObject;
|
||||
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
@ -30,7 +34,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
|
|||
import java.util.*;
|
||||
|
||||
import static org.bytedeco.cpython.global.python.*;
|
||||
import static org.bytedeco.cpython.global.python.PyObject_SetItem;
|
||||
import static org.bytedeco.numpy.global.numpy.*;
|
||||
|
||||
/**
|
||||
* Swift like python wrapper for J
|
||||
|
@ -38,6 +42,7 @@ import static org.bytedeco.cpython.global.python.PyObject_SetItem;
|
|||
* @author Fariz Rahman
|
||||
*/
|
||||
|
||||
@Slf4j
|
||||
public class PythonObject {
|
||||
private PyObject nativePythonObject;
|
||||
|
||||
|
@ -77,78 +82,69 @@ public class PythonObject {
|
|||
}
|
||||
|
||||
public PythonObject(NumpyArray npArray) {
|
||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||
PyObject np = PyImport_ImportModule("numpy");
|
||||
PyObject ctype;
|
||||
switch (npArray.getDtype()) {
|
||||
int numpyType;
|
||||
INDArray indArray = npArray.getNd4jArray();
|
||||
DataType dataType = indArray.dataType();
|
||||
|
||||
switch (dataType) {
|
||||
case DOUBLE:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_double");
|
||||
numpyType = NPY_DOUBLE;
|
||||
break;
|
||||
case FLOAT:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_float");
|
||||
break;
|
||||
case LONG:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_int64");
|
||||
break;
|
||||
case INT:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_int32");
|
||||
case BFLOAT16:
|
||||
numpyType = NPY_FLOAT;
|
||||
break;
|
||||
case SHORT:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_int16");
|
||||
numpyType = NPY_SHORT;
|
||||
break;
|
||||
case INT:
|
||||
numpyType = NPY_INT;
|
||||
break;
|
||||
case LONG:
|
||||
numpyType = NPY_INT64;
|
||||
break;
|
||||
case UINT16:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_uint16");
|
||||
numpyType = NPY_USHORT;
|
||||
break;
|
||||
case UINT32:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_uint32");
|
||||
numpyType = NPY_UINT;
|
||||
break;
|
||||
case UINT64:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_uint64");
|
||||
numpyType = NPY_UINT64;
|
||||
break;
|
||||
case BOOL:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_bool");
|
||||
numpyType = NPY_BOOL;
|
||||
break;
|
||||
case BYTE:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_byte");
|
||||
numpyType = NPY_BYTE;
|
||||
break;
|
||||
case UBYTE:
|
||||
ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
|
||||
numpyType = NPY_UBYTE;
|
||||
break;
|
||||
case HALF:
|
||||
numpyType = NPY_HALF;
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
|
||||
}
|
||||
|
||||
PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
|
||||
PyObject argsTuple = PyTuple_New(1);
|
||||
PyTuple_SetItem(argsTuple, 0, ctype);
|
||||
PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
|
||||
|
||||
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
|
||||
PyObject address = PyLong_FromLong(npArray.getAddress());
|
||||
PyObject argsTuple2 = PyTuple_New(2);
|
||||
PyTuple_SetItem(argsTuple2, 0, address);
|
||||
PyTuple_SetItem(argsTuple2, 1, ptrType);
|
||||
PyObject ptr = PyObject_Call(cast, argsTuple2, null);
|
||||
PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
|
||||
for (int i = 0; i < npArray.getShape().length; i++) {
|
||||
PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
|
||||
PyTuple_SetItem(shapeTuple, i, dim);
|
||||
Py_DecRef(dim);
|
||||
long[] shape = indArray.shape();
|
||||
INDArray inputArray = indArray;
|
||||
if(dataType == DataType.BFLOAT16) {
|
||||
log.warn("\n\nThe given nd4j array \n\n{}\n\n is of BFLOAT16 datatype. " +
|
||||
"Casting a copy of it to FLOAT and creating the respective numpy array from it.\n", indArray);
|
||||
inputArray = indArray.castTo(DataType.FLOAT);
|
||||
}
|
||||
PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
|
||||
PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
|
||||
PyObject argsTuple3 = PyTuple_New(2);
|
||||
PyTuple_SetItem(argsTuple3, 0, ptr);
|
||||
PyTuple_SetItem(argsTuple3, 1, shapeTuple);
|
||||
nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
|
||||
|
||||
Py_DecRef(ctypesPointer);
|
||||
Py_DecRef(ctypesLib);
|
||||
Py_DecRef(argsTuple);
|
||||
Py_DecRef(argsTuple2);
|
||||
Py_DecRef(argsTuple3);
|
||||
Py_DecRef(cast);
|
||||
Py_DecRef(asArray);
|
||||
//Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
|
||||
if(inputArray.data() instanceof BaseDataBuffer){
|
||||
((BaseDataBuffer)inputArray.data()).syncToPrimary();
|
||||
}
|
||||
|
||||
nativePythonObject = PyArray_New(PyArray_Type(), shape.length, new SizeTPointer(shape),
|
||||
numpyType, null,
|
||||
inputArray.data().addressPointer(),
|
||||
0, NPY_ARRAY_CARRAY, null);
|
||||
|
||||
}
|
||||
|
||||
|
@ -321,57 +317,60 @@ public class PythonObject {
|
|||
return toInt() != 0;
|
||||
}
|
||||
|
||||
public NumpyArray toNumpy() {
|
||||
PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
|
||||
PyObject data = PyDict_GetItemString(arrInterface, "data");
|
||||
PyObject pyAddress = PyTuple_GetItem(data, 0);
|
||||
long address = PyLong_AsLong(pyAddress);
|
||||
PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
|
||||
PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
|
||||
String dtypeName = pyObjectToString(pyDtypeName);
|
||||
Py_DecRef(pyDtype);
|
||||
Py_DecRef(pyDtypeName);
|
||||
PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
|
||||
PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
|
||||
int ndim = (int) PyObject_Size(shape);
|
||||
long[] jshape = new long[ndim];
|
||||
long[] jstrides = new long[ndim];
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
|
||||
jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
|
||||
public NumpyArray toNumpy() throws PythonException{
|
||||
PyObject np = PyImport_ImportModule("numpy");
|
||||
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
|
||||
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){
|
||||
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
|
||||
}
|
||||
Py_DecRef(shape);
|
||||
Py_DecRef(strides);
|
||||
Py_DecRef(ndarray);
|
||||
Py_DecRef(np);
|
||||
Pointer objPtr = new Pointer(nativePythonObject);
|
||||
PyArrayObject npArr = new PyArrayObject(objPtr);
|
||||
Pointer ptr = PyArray_DATA(npArr);
|
||||
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
||||
long[] shape = new long[PyArray_NDIM(npArr)];
|
||||
shapePtr.get(shape, 0, shape.length);
|
||||
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
||||
long[] strides = new long[shape.length];
|
||||
stridesPtr.get(strides, 0, strides.length);
|
||||
int npdtype = PyArray_TYPE(npArr);
|
||||
|
||||
DataType dtype;
|
||||
if (dtypeName.equals("float64")) {
|
||||
dtype = DataType.DOUBLE;
|
||||
} else if (dtypeName.equals("float32")) {
|
||||
dtype = DataType.FLOAT;
|
||||
} else if (dtypeName.equals("int8")){
|
||||
dtype = DataType.INT8;
|
||||
}else if (dtypeName.equals("int16")) {
|
||||
dtype = DataType.SHORT;
|
||||
} else if (dtypeName.equals("int32")) {
|
||||
dtype = DataType.INT;
|
||||
} else if (dtypeName.equals("int64")) {
|
||||
dtype = DataType.LONG;
|
||||
switch (npdtype){
|
||||
case NPY_DOUBLE:
|
||||
dtype = DataType.DOUBLE; break;
|
||||
case NPY_FLOAT:
|
||||
dtype = DataType.FLOAT; break;
|
||||
case NPY_SHORT:
|
||||
dtype = DataType.SHORT; break;
|
||||
case NPY_INT:
|
||||
dtype = DataType.INT; break;
|
||||
case NPY_LONG:
|
||||
dtype = DataType.LONG; break;
|
||||
case NPY_UINT:
|
||||
dtype = DataType.UINT32; break;
|
||||
case NPY_BYTE:
|
||||
dtype = DataType.BYTE; break;
|
||||
case NPY_UBYTE:
|
||||
dtype = DataType.UBYTE; break;
|
||||
case NPY_BOOL:
|
||||
dtype = DataType.BOOL; break;
|
||||
case NPY_HALF:
|
||||
dtype = DataType.HALF; break;
|
||||
case NPY_LONGLONG:
|
||||
dtype = DataType.INT64; break;
|
||||
case NPY_USHORT:
|
||||
dtype = DataType.UINT16; break;
|
||||
case NPY_ULONG:
|
||||
dtype = DataType.UINT64; break;
|
||||
case NPY_ULONGLONG:
|
||||
dtype = DataType.UINT64; break;
|
||||
default:
|
||||
throw new PythonException("Unsupported array data type: " + npdtype);
|
||||
}
|
||||
else if (dtypeName.equals("uint8")){
|
||||
dtype = DataType.UINT8;
|
||||
}
|
||||
else if (dtypeName.equals("uint16")){
|
||||
dtype = DataType.UINT16;
|
||||
}
|
||||
else if (dtypeName.equals("uint32")){
|
||||
dtype = DataType.UINT32;
|
||||
}
|
||||
else if (dtypeName.equals("uint64")){
|
||||
dtype = DataType.UINT64;
|
||||
}
|
||||
else {
|
||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
||||
}
|
||||
return new NumpyArray(address, jshape, jstrides, dtype);
|
||||
|
||||
return new NumpyArray(ptr.address(), shape, strides, dtype);
|
||||
|
||||
}
|
||||
|
||||
|
@ -442,7 +441,6 @@ public class PythonObject {
|
|||
return get(key.nativePythonObject);
|
||||
}
|
||||
|
||||
|
||||
public PythonObject get(int key) {
|
||||
return get(PyLong_FromLong((long) key));
|
||||
}
|
||||
|
@ -450,14 +448,12 @@ public class PythonObject {
|
|||
public PythonObject get(long key) {
|
||||
return new PythonObject(
|
||||
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
public PythonObject get(double key) {
|
||||
return new PythonObject(
|
||||
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -479,7 +475,6 @@ public class PythonObject {
|
|||
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
|
||||
String jsonString = serialized.toString();
|
||||
return new JSONArray(jsonString);
|
||||
|
||||
}
|
||||
|
||||
public JSONObject toJSONObject() throws PythonException {
|
||||
|
@ -547,13 +542,16 @@ public class PythonObject {
|
|||
String originalContext = Python.getCurrentContext();
|
||||
Python.setContext(tempContext);
|
||||
PythonExecutioner.setVariable("memview", this);
|
||||
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
|
||||
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
|
||||
PythonExecutioner.exec("import numpy as np\narr = np.frombuffer(memview, dtype='int8')");
|
||||
INDArray arr = PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray();
|
||||
if(arr.data() instanceof BaseDataBuffer){
|
||||
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||
}
|
||||
BytePointer ret = new BytePointer(arr.data().pointer());
|
||||
Python.setContext(originalContext);
|
||||
Python.deleteContext(tempContext);
|
||||
return ret;
|
||||
}
|
||||
else{
|
||||
} else {
|
||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
|
||||
|
@ -579,13 +577,10 @@ public class PythonObject {
|
|||
else{
|
||||
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
public boolean isNone() {
|
||||
return (nativePythonObject == null)||
|
||||
(toString().equals("None") && Python.type(this).toString().equals("<class 'NoneType'>"));
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -314,7 +314,7 @@ public class PythonVariables implements java.io.Serializable {
|
|||
* @param name the field to add
|
||||
* @param value the value to add
|
||||
*/
|
||||
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
||||
public void addNDArray(String name, INDArray value) {
|
||||
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||
ndVars.put(name, value);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.python;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static junit.framework.TestCase.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class PythonNumpyTest {
|
||||
|
||||
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
|
||||
public static DataType[] data() {
|
||||
return new DataType[] {
|
||||
DataType.BOOL,
|
||||
DataType.FLOAT16,
|
||||
DataType.BFLOAT16,
|
||||
DataType.FLOAT,
|
||||
DataType.DOUBLE,
|
||||
DataType.INT8,
|
||||
DataType.INT16,
|
||||
DataType.INT32,
|
||||
DataType.INT64,
|
||||
DataType.UINT8,
|
||||
DataType.UINT16,
|
||||
DataType.UINT32,
|
||||
DataType.UINT64
|
||||
};
|
||||
}
|
||||
|
||||
private DataType dataType;
|
||||
|
||||
public PythonNumpyTest(DataType dataType) {
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void numpyAndNd4jConversions() throws Exception {
|
||||
INDArray input = Nd4j.ones(dataType, 2, 2, 2);
|
||||
|
||||
PythonVariables inputs = new PythonVariables();
|
||||
inputs.addNDArray("x", input);
|
||||
|
||||
PythonVariables outputs = new PythonVariables();
|
||||
outputs.addNDArray("y");
|
||||
|
||||
PythonJob pythonJob = new PythonJob(String.format("job_%s", dataType.name()) + dataType.name(), "y = x", false);
|
||||
|
||||
pythonJob.exec(inputs, outputs);
|
||||
|
||||
INDArray output = outputs.getNDArrayValue("y");
|
||||
|
||||
// As numpy doesn't support BFLOAT16 we'll convert it to FLOAT
|
||||
assertEquals(dataType == DataType.BFLOAT16 ? input.castTo(DataType.FLOAT) : input,
|
||||
output);
|
||||
}
|
||||
}
|
|
@ -20,10 +20,13 @@ import org.bytedeco.javacpp.BytePointer;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.nativeblas.OpaqueDataBuffer;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
@ -223,6 +226,35 @@ public class TestPythonExecutioner {
|
|||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNDArrayNoCopy() throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
INDArray arr = Nd4j.rand(3, 2);
|
||||
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||
pyInputs.addNDArray("x", arr);
|
||||
pyOutputs.addNDArray("x");
|
||||
INDArray expected = arr.mul(2.3);
|
||||
String code = "x *= 2.3";
|
||||
Python.exec(code, pyInputs, pyOutputs);
|
||||
Assert.assertEquals(pyInputs.getNDArrayValue("x"), pyOutputs.getNDArrayValue("x"));
|
||||
Assert.assertEquals(expected, pyOutputs.getNDArrayValue("x"));
|
||||
Assert.assertEquals(arr.data().address(), pyOutputs.getNDArrayValue("x").data().address());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayInplace() throws Exception{
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
INDArray arr = Nd4j.rand(3, 2);
|
||||
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||
pyInputs.addNDArray("x", arr);
|
||||
INDArray expected = arr.mul(2.3);
|
||||
String code = "x *= 2.3";
|
||||
Python.exec(code, pyInputs, null);
|
||||
Assert.assertEquals(expected, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteBufferInput() throws Exception{
|
||||
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
|
||||
|
@ -230,8 +262,7 @@ public class TestPythonExecutioner {
|
|||
buff.putScalar(0, 97); // a
|
||||
buff.putScalar(1, 98); // b
|
||||
buff.putScalar(2, 99); // c
|
||||
|
||||
|
||||
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||
|
||||
|
@ -251,6 +282,7 @@ public class TestPythonExecutioner {
|
|||
buff.putScalar(0, 97); // a
|
||||
buff.putScalar(1, 98); // b
|
||||
buff.putScalar(2, 99); // c
|
||||
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||
|
||||
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
|
@ -262,6 +294,7 @@ public class TestPythonExecutioner {
|
|||
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
|
||||
Python.exec(code, pyInputs, pyOutputs);
|
||||
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||
Assert.assertEquals(buff.data().address(), pyOutputs.getBytesValue("buff").address());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -270,6 +303,8 @@ public class TestPythonExecutioner {
|
|||
buff.putScalar(0, 97); // a
|
||||
buff.putScalar(1, 98); // b
|
||||
buff.putScalar(2, 99); // c
|
||||
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||
String code = "buff[0]+=2\nbuff[2]-=2";
|
||||
|
@ -288,6 +323,7 @@ public class TestPythonExecutioner {
|
|||
buff.putScalar(0, 97); // a
|
||||
buff.putScalar(1, 98); // b
|
||||
buff.putScalar(2, 99); // c
|
||||
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||
|
||||
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
|
@ -300,6 +336,30 @@ public class TestPythonExecutioner {
|
|||
Python.exec(code, pyInputs, pyOutputs);
|
||||
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleDeviceAllocation() throws Exception{
|
||||
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
|
||||
return;
|
||||
}
|
||||
// Test to make sure that multiple device buffers are not allocated
|
||||
// for the same host buffer
|
||||
INDArray arr = Nd4j.rand(3, 2);
|
||||
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||
long deviceAddress1 = getDeviceAddress(arr);
|
||||
PythonVariables pyInputs = new PythonVariables();
|
||||
pyInputs.addNDArray("arr", arr);
|
||||
PythonVariables pyOutputs = new PythonVariables();
|
||||
pyOutputs.addNDArray("arr");
|
||||
String code = "arr += 2";
|
||||
Python.exec(code, pyInputs, pyOutputs);
|
||||
INDArray arr2 = pyOutputs.getNDArrayValue("arr");
|
||||
long deviceAddress2 = getDeviceAddress(arr2);
|
||||
Assert.assertEquals(deviceAddress1, deviceAddress2);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBadCode() throws Exception{
|
||||
Python.setContext("badcode");
|
||||
|
@ -333,5 +393,22 @@ public class TestPythonExecutioner {
|
|||
Assert.assertEquals("y", notNone.toString());
|
||||
}
|
||||
|
||||
private static long getDeviceAddress(INDArray array){
|
||||
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
|
||||
throw new IllegalStateException("Cannot ge device pointer for non-CUDA device");
|
||||
}
|
||||
|
||||
//Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
|
||||
// due to it being defined in nd4j-native-api, not nd4j-api
|
||||
try {
|
||||
Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
|
||||
Method m = c.getMethod("getOpaqueDataBuffer");
|
||||
OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
|
||||
long address = db.specialBuffer().address();
|
||||
return address;
|
||||
} catch (Throwable t){
|
||||
throw new RuntimeException("Error getting OpaqueDataBuffer", t);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ package org.datavec.python;
|
|||
|
||||
import org.bytedeco.javacpp.BytePointer;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -55,6 +56,7 @@ public class TestPythonVariables {
|
|||
};
|
||||
|
||||
INDArray arr = Nd4j.scalar(1.0);
|
||||
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||
BytePointer bp = new BytePointer(arr.data().pointer());
|
||||
Object[] values = {
|
||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||
|
|
|
@ -1945,4 +1945,17 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
public boolean wasClosed() {
|
||||
return released;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This method synchronizes host memory
|
||||
*/
|
||||
public abstract void syncToPrimary();
|
||||
|
||||
/**
|
||||
* This method synchronizes device memory
|
||||
*/
|
||||
public abstract void syncToSpecial();
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -230,13 +230,14 @@ public enum DataType {
|
|||
public static DataType fromNumpy(String numpyDtypeName){
|
||||
switch (numpyDtypeName.toLowerCase()){
|
||||
case "bool": return BOOL;
|
||||
case "byte": return BYTE;
|
||||
case "int8": return BYTE;
|
||||
case "int16": return SHORT;
|
||||
case "int32": return INT;
|
||||
case "int64": return LONG;
|
||||
case "uint8": return UBYTE;
|
||||
case "float16": return HALF;
|
||||
case "byte":
|
||||
case "int8":
|
||||
return INT8;
|
||||
case "int16": return INT16;
|
||||
case "int32": return INT32;
|
||||
case "int64": return INT64;
|
||||
case "uint8": return UINT8;
|
||||
case "float16": return FLOAT16;
|
||||
case "float32": return FLOAT;
|
||||
case "float64": return DOUBLE;
|
||||
case "uint16": return UINT16;
|
||||
|
|
|
@ -211,6 +211,16 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
|||
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToPrimary() {
|
||||
//No-op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToSpecial() {
|
||||
//No-op
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double getDoubleUnsynced(long index) {
|
||||
return super.getDouble(index);
|
||||
|
|
|
@ -51,7 +51,6 @@ public class OpaqueDataBuffer extends Pointer {
|
|||
try {
|
||||
// try to allocate data buffer
|
||||
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
|
||||
|
||||
// check error code
|
||||
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||
if (ec != 0) {
|
||||
|
|
|
@ -1845,4 +1845,14 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
public int targetDevice() {
|
||||
return AtomicAllocator.getInstance().getAllocationPoint(this).getDeviceId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToPrimary(){
|
||||
ptrDataBuffer.syncToPrimary();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToSpecial(){
|
||||
ptrDataBuffer.syncToSpecial();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -951,4 +951,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToPrimary(){
|
||||
ptrDataBuffer.syncToPrimary();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncToSpecial(){
|
||||
ptrDataBuffer.syncToSpecial();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,6 +32,8 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import sun.nio.ch.DirectBuffer;
|
||||
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -398,6 +400,29 @@ public class DataBufferTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEnsureLocation(){
|
||||
//https://github.com/eclipse/deeplearning4j/issues/8783
|
||||
Nd4j.create(1);
|
||||
|
||||
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
|
||||
System.out.println(bb.getClass());
|
||||
System.out.println(bb.address());
|
||||
|
||||
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
|
||||
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
|
||||
|
||||
|
||||
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
|
||||
long before = arr2.data().pointer().address();
|
||||
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
|
||||
long after = arr2.data().pointer().address();
|
||||
|
||||
assertEquals(before, after);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue