commit
7c05928185
|
@ -1,29 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.api.transform.analysis.counter;
|
|
||||||
|
|
||||||
import java.util.function.BiFunction;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 5/03/2016.
|
|
||||||
*/
|
|
||||||
public class StringAnalysisMergeFunction
|
|
||||||
implements BiFunction<StringAnalysisCounter, StringAnalysisCounter, StringAnalysisCounter> {
|
|
||||||
public StringAnalysisCounter apply(StringAnalysisCounter v1, StringAnalysisCounter v2) {
|
|
||||||
return v1.merge(v2);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -19,6 +19,7 @@ package org.datavec.python;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
|
@ -29,6 +30,10 @@ import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +47,7 @@ import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
public class NumpyArray {
|
public class NumpyArray {
|
||||||
|
|
||||||
private static NativeOps nativeOps;
|
private static NativeOps nativeOps;
|
||||||
|
private static Map<String, INDArray> arrayCache; // Avoids re-allocation of device buffer
|
||||||
private long address;
|
private long address;
|
||||||
private long[] shape;
|
private long[] shape;
|
||||||
private long[] strides;
|
private long[] strides;
|
||||||
|
@ -52,6 +58,7 @@ public class NumpyArray {
|
||||||
//initialize
|
//initialize
|
||||||
Nd4j.scalar(1.0);
|
Nd4j.scalar(1.0);
|
||||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
arrayCache = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -84,24 +91,42 @@ public class NumpyArray {
|
||||||
|
|
||||||
|
|
||||||
private void setND4JArray() {
|
private void setND4JArray() {
|
||||||
|
|
||||||
long size = 1;
|
long size = 1;
|
||||||
for (long d : shape) {
|
for (long d : shape) {
|
||||||
size *= d;
|
size *= d;
|
||||||
}
|
}
|
||||||
Pointer ptr = nativeOps.pointerForAddress(address);
|
|
||||||
ptr = ptr.limit(size);
|
|
||||||
ptr = ptr.capacity(size);
|
|
||||||
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
|
|
||||||
int elemSize = buff.getElementSize();
|
|
||||||
long[] nd4jStrides = new long[strides.length];
|
|
||||||
for (int i = 0; i < strides.length; i++) {
|
|
||||||
nd4jStrides[i] = strides[i] / elemSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
|
String cacheKey = address + "_" + size + "_" + dtype + "_" + ArrayUtils.toString(strides);
|
||||||
|
nd4jArray = arrayCache.get(cacheKey);
|
||||||
|
if (nd4jArray == null) {
|
||||||
|
Pointer ptr = nativeOps.pointerForAddress(address);
|
||||||
|
ptr = ptr.limit(size);
|
||||||
|
ptr = ptr.capacity(size);
|
||||||
|
DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype);
|
||||||
|
|
||||||
|
int elemSize = buff.getElementSize();
|
||||||
|
long[] nd4jStrides = new long[strides.length];
|
||||||
|
for (int i = 0; i < strides.length; i++) {
|
||||||
|
nd4jStrides[i] = strides[i] / elemSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
|
||||||
|
arrayCache.put(cacheKey, nd4jArray);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
if (!Arrays.equals(nd4jArray.shape(), shape)){
|
||||||
|
nd4jArray = nd4jArray.reshape(shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public INDArray getNd4jArray(){
|
||||||
|
Nd4j.getAffinityManager().tagLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
|
return nd4jArray;
|
||||||
|
}
|
||||||
|
|
||||||
public NumpyArray(INDArray nd4jArray) {
|
public NumpyArray(INDArray nd4jArray) {
|
||||||
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
DataBuffer buff = nd4jArray.data();
|
DataBuffer buff = nd4jArray.data();
|
||||||
|
@ -115,6 +140,8 @@ public class NumpyArray {
|
||||||
}
|
}
|
||||||
dtype = nd4jArray.dataType();
|
dtype = nd4jArray.dataType();
|
||||||
this.nd4jArray = nd4jArray;
|
this.nd4jArray = nd4jArray;
|
||||||
|
String cacheKey = address + "_" + nd4jArray.length() + "_" + dtype + "_" + ArrayUtils.toString(strides);
|
||||||
|
arrayCache.put(cacheKey, nd4jArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -21,6 +21,7 @@ package org.datavec.python;
|
||||||
import org.bytedeco.cpython.PyObject;
|
import org.bytedeco.cpython.PyObject;
|
||||||
|
|
||||||
import static org.bytedeco.cpython.global.python.*;
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
|
import static org.bytedeco.numpy.global.numpy.PyArray_EnsureArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Swift like python wrapper for Java
|
* Swift like python wrapper for Java
|
||||||
|
@ -232,6 +233,10 @@ public class Python {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static PythonObject ndarray(PythonObject pythonObject){
|
||||||
|
return new PythonObject(PyArray_EnsureArray(pythonObject.getNativePythonObject()));
|
||||||
|
}
|
||||||
|
|
||||||
public static boolean callable(PythonObject pythonObject) {
|
public static boolean callable(PythonObject pythonObject) {
|
||||||
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
|
return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,9 @@ package org.datavec.python;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.bytedeco.numpy.global.numpy;
|
import org.bytedeco.numpy.global.numpy;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
|
|
@ -18,11 +18,15 @@
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.cpython.PyObject;
|
import org.bytedeco.cpython.PyObject;
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.bytedeco.javacpp.SizeTPointer;
|
||||||
|
import org.bytedeco.numpy.PyArrayObject;
|
||||||
import org.json.JSONArray;
|
import org.json.JSONArray;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
|
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
@ -30,7 +34,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.bytedeco.cpython.global.python.*;
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
import static org.bytedeco.cpython.global.python.PyObject_SetItem;
|
import static org.bytedeco.numpy.global.numpy.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Swift like python wrapper for J
|
* Swift like python wrapper for J
|
||||||
|
@ -38,6 +42,7 @@ import static org.bytedeco.cpython.global.python.PyObject_SetItem;
|
||||||
* @author Fariz Rahman
|
* @author Fariz Rahman
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class PythonObject {
|
public class PythonObject {
|
||||||
private PyObject nativePythonObject;
|
private PyObject nativePythonObject;
|
||||||
|
|
||||||
|
@ -77,78 +82,69 @@ public class PythonObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
public PythonObject(NumpyArray npArray) {
|
public PythonObject(NumpyArray npArray) {
|
||||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
int numpyType;
|
||||||
PyObject np = PyImport_ImportModule("numpy");
|
INDArray indArray = npArray.getNd4jArray();
|
||||||
PyObject ctype;
|
DataType dataType = indArray.dataType();
|
||||||
switch (npArray.getDtype()) {
|
|
||||||
|
switch (dataType) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_double");
|
numpyType = NPY_DOUBLE;
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_float");
|
case BFLOAT16:
|
||||||
break;
|
numpyType = NPY_FLOAT;
|
||||||
case LONG:
|
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_int64");
|
|
||||||
break;
|
|
||||||
case INT:
|
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_int32");
|
|
||||||
break;
|
break;
|
||||||
case SHORT:
|
case SHORT:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_int16");
|
numpyType = NPY_SHORT;
|
||||||
|
break;
|
||||||
|
case INT:
|
||||||
|
numpyType = NPY_INT;
|
||||||
|
break;
|
||||||
|
case LONG:
|
||||||
|
numpyType = NPY_INT64;
|
||||||
break;
|
break;
|
||||||
case UINT16:
|
case UINT16:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_uint16");
|
numpyType = NPY_USHORT;
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_uint32");
|
numpyType = NPY_UINT;
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_uint64");
|
numpyType = NPY_UINT64;
|
||||||
break;
|
break;
|
||||||
case BOOL:
|
case BOOL:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_bool");
|
numpyType = NPY_BOOL;
|
||||||
break;
|
break;
|
||||||
case BYTE:
|
case BYTE:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_byte");
|
numpyType = NPY_BYTE;
|
||||||
break;
|
break;
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
|
numpyType = NPY_UBYTE;
|
||||||
|
break;
|
||||||
|
case HALF:
|
||||||
|
numpyType = NPY_HALF;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
|
throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
|
long[] shape = indArray.shape();
|
||||||
PyObject argsTuple = PyTuple_New(1);
|
INDArray inputArray = indArray;
|
||||||
PyTuple_SetItem(argsTuple, 0, ctype);
|
if(dataType == DataType.BFLOAT16) {
|
||||||
PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
|
log.warn("\n\nThe given nd4j array \n\n{}\n\n is of BFLOAT16 datatype. " +
|
||||||
|
"Casting a copy of it to FLOAT and creating the respective numpy array from it.\n", indArray);
|
||||||
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
|
inputArray = indArray.castTo(DataType.FLOAT);
|
||||||
PyObject address = PyLong_FromLong(npArray.getAddress());
|
|
||||||
PyObject argsTuple2 = PyTuple_New(2);
|
|
||||||
PyTuple_SetItem(argsTuple2, 0, address);
|
|
||||||
PyTuple_SetItem(argsTuple2, 1, ptrType);
|
|
||||||
PyObject ptr = PyObject_Call(cast, argsTuple2, null);
|
|
||||||
PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
|
|
||||||
for (int i = 0; i < npArray.getShape().length; i++) {
|
|
||||||
PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
|
|
||||||
PyTuple_SetItem(shapeTuple, i, dim);
|
|
||||||
Py_DecRef(dim);
|
|
||||||
}
|
}
|
||||||
PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
|
|
||||||
PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
|
|
||||||
PyObject argsTuple3 = PyTuple_New(2);
|
|
||||||
PyTuple_SetItem(argsTuple3, 0, ptr);
|
|
||||||
PyTuple_SetItem(argsTuple3, 1, shapeTuple);
|
|
||||||
nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
|
|
||||||
|
|
||||||
Py_DecRef(ctypesPointer);
|
//Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
|
||||||
Py_DecRef(ctypesLib);
|
if(inputArray.data() instanceof BaseDataBuffer){
|
||||||
Py_DecRef(argsTuple);
|
((BaseDataBuffer)inputArray.data()).syncToPrimary();
|
||||||
Py_DecRef(argsTuple2);
|
}
|
||||||
Py_DecRef(argsTuple3);
|
|
||||||
Py_DecRef(cast);
|
nativePythonObject = PyArray_New(PyArray_Type(), shape.length, new SizeTPointer(shape),
|
||||||
Py_DecRef(asArray);
|
numpyType, null,
|
||||||
|
inputArray.data().addressPointer(),
|
||||||
|
0, NPY_ARRAY_CARRAY, null);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -321,57 +317,60 @@ public class PythonObject {
|
||||||
return toInt() != 0;
|
return toInt() != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray toNumpy() {
|
public NumpyArray toNumpy() throws PythonException{
|
||||||
PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
|
PyObject np = PyImport_ImportModule("numpy");
|
||||||
PyObject data = PyDict_GetItemString(arrInterface, "data");
|
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
|
||||||
PyObject pyAddress = PyTuple_GetItem(data, 0);
|
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){
|
||||||
long address = PyLong_AsLong(pyAddress);
|
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
|
||||||
PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
|
|
||||||
PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
|
|
||||||
String dtypeName = pyObjectToString(pyDtypeName);
|
|
||||||
Py_DecRef(pyDtype);
|
|
||||||
Py_DecRef(pyDtypeName);
|
|
||||||
PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
|
|
||||||
PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
|
|
||||||
int ndim = (int) PyObject_Size(shape);
|
|
||||||
long[] jshape = new long[ndim];
|
|
||||||
long[] jstrides = new long[ndim];
|
|
||||||
for (int i = 0; i < ndim; i++) {
|
|
||||||
jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
|
|
||||||
jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
|
|
||||||
}
|
}
|
||||||
Py_DecRef(shape);
|
Py_DecRef(ndarray);
|
||||||
Py_DecRef(strides);
|
Py_DecRef(np);
|
||||||
|
Pointer objPtr = new Pointer(nativePythonObject);
|
||||||
|
PyArrayObject npArr = new PyArrayObject(objPtr);
|
||||||
|
Pointer ptr = PyArray_DATA(npArr);
|
||||||
|
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
||||||
|
long[] shape = new long[PyArray_NDIM(npArr)];
|
||||||
|
shapePtr.get(shape, 0, shape.length);
|
||||||
|
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
||||||
|
long[] strides = new long[shape.length];
|
||||||
|
stridesPtr.get(strides, 0, strides.length);
|
||||||
|
int npdtype = PyArray_TYPE(npArr);
|
||||||
|
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
if (dtypeName.equals("float64")) {
|
switch (npdtype){
|
||||||
dtype = DataType.DOUBLE;
|
case NPY_DOUBLE:
|
||||||
} else if (dtypeName.equals("float32")) {
|
dtype = DataType.DOUBLE; break;
|
||||||
dtype = DataType.FLOAT;
|
case NPY_FLOAT:
|
||||||
} else if (dtypeName.equals("int8")){
|
dtype = DataType.FLOAT; break;
|
||||||
dtype = DataType.INT8;
|
case NPY_SHORT:
|
||||||
}else if (dtypeName.equals("int16")) {
|
dtype = DataType.SHORT; break;
|
||||||
dtype = DataType.SHORT;
|
case NPY_INT:
|
||||||
} else if (dtypeName.equals("int32")) {
|
dtype = DataType.INT; break;
|
||||||
dtype = DataType.INT;
|
case NPY_LONG:
|
||||||
} else if (dtypeName.equals("int64")) {
|
dtype = DataType.LONG; break;
|
||||||
dtype = DataType.LONG;
|
case NPY_UINT:
|
||||||
|
dtype = DataType.UINT32; break;
|
||||||
|
case NPY_BYTE:
|
||||||
|
dtype = DataType.BYTE; break;
|
||||||
|
case NPY_UBYTE:
|
||||||
|
dtype = DataType.UBYTE; break;
|
||||||
|
case NPY_BOOL:
|
||||||
|
dtype = DataType.BOOL; break;
|
||||||
|
case NPY_HALF:
|
||||||
|
dtype = DataType.HALF; break;
|
||||||
|
case NPY_LONGLONG:
|
||||||
|
dtype = DataType.INT64; break;
|
||||||
|
case NPY_USHORT:
|
||||||
|
dtype = DataType.UINT16; break;
|
||||||
|
case NPY_ULONG:
|
||||||
|
dtype = DataType.UINT64; break;
|
||||||
|
case NPY_ULONGLONG:
|
||||||
|
dtype = DataType.UINT64; break;
|
||||||
|
default:
|
||||||
|
throw new PythonException("Unsupported array data type: " + npdtype);
|
||||||
}
|
}
|
||||||
else if (dtypeName.equals("uint8")){
|
|
||||||
dtype = DataType.UINT8;
|
return new NumpyArray(ptr.address(), shape, strides, dtype);
|
||||||
}
|
|
||||||
else if (dtypeName.equals("uint16")){
|
|
||||||
dtype = DataType.UINT16;
|
|
||||||
}
|
|
||||||
else if (dtypeName.equals("uint32")){
|
|
||||||
dtype = DataType.UINT32;
|
|
||||||
}
|
|
||||||
else if (dtypeName.equals("uint64")){
|
|
||||||
dtype = DataType.UINT64;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
throw new RuntimeException("Unsupported array type " + dtypeName + ".");
|
|
||||||
}
|
|
||||||
return new NumpyArray(address, jshape, jstrides, dtype);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -442,7 +441,6 @@ public class PythonObject {
|
||||||
return get(key.nativePythonObject);
|
return get(key.nativePythonObject);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public PythonObject get(int key) {
|
public PythonObject get(int key) {
|
||||||
return get(PyLong_FromLong((long) key));
|
return get(PyLong_FromLong((long) key));
|
||||||
}
|
}
|
||||||
|
@ -450,14 +448,12 @@ public class PythonObject {
|
||||||
public PythonObject get(long key) {
|
public PythonObject get(long key) {
|
||||||
return new PythonObject(
|
return new PythonObject(
|
||||||
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
|
PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
|
||||||
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public PythonObject get(double key) {
|
public PythonObject get(double key) {
|
||||||
return new PythonObject(
|
return new PythonObject(
|
||||||
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
|
PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
|
||||||
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -479,7 +475,6 @@ public class PythonObject {
|
||||||
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
|
PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
|
||||||
String jsonString = serialized.toString();
|
String jsonString = serialized.toString();
|
||||||
return new JSONArray(jsonString);
|
return new JSONArray(jsonString);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public JSONObject toJSONObject() throws PythonException {
|
public JSONObject toJSONObject() throws PythonException {
|
||||||
|
@ -547,13 +542,16 @@ public class PythonObject {
|
||||||
String originalContext = Python.getCurrentContext();
|
String originalContext = Python.getCurrentContext();
|
||||||
Python.setContext(tempContext);
|
Python.setContext(tempContext);
|
||||||
PythonExecutioner.setVariable("memview", this);
|
PythonExecutioner.setVariable("memview", this);
|
||||||
PythonExecutioner.exec("import numpy as np\narr = np.array(memview)");
|
PythonExecutioner.exec("import numpy as np\narr = np.frombuffer(memview, dtype='int8')");
|
||||||
BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer());
|
INDArray arr = PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray();
|
||||||
|
if(arr.data() instanceof BaseDataBuffer){
|
||||||
|
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||||
|
}
|
||||||
|
BytePointer ret = new BytePointer(arr.data().pointer());
|
||||||
Python.setContext(originalContext);
|
Python.setContext(originalContext);
|
||||||
Python.deleteContext(tempContext);
|
Python.deleteContext(tempContext);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
PyObject ctypes = PyImport_ImportModule("ctypes");
|
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||||
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||||
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
|
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
|
||||||
|
@ -579,13 +577,10 @@ public class PythonObject {
|
||||||
else{
|
else{
|
||||||
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
|
throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
public boolean isNone() {
|
public boolean isNone() {
|
||||||
return (nativePythonObject == null)||
|
return (nativePythonObject == null)||
|
||||||
(toString().equals("None") && Python.type(this).toString().equals("<class 'NoneType'>"));
|
(toString().equals("None") && Python.type(this).toString().equals("<class 'NoneType'>"));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -314,7 +314,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
public void addNDArray(String name, INDArray value) {
|
||||||
vars.put(name, PythonType.TypeName.NDARRAY);
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
ndVars.put(name, value);
|
ndVars.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static junit.framework.TestCase.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
public class PythonNumpyTest {
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
|
||||||
|
public static DataType[] data() {
|
||||||
|
return new DataType[] {
|
||||||
|
DataType.BOOL,
|
||||||
|
DataType.FLOAT16,
|
||||||
|
DataType.BFLOAT16,
|
||||||
|
DataType.FLOAT,
|
||||||
|
DataType.DOUBLE,
|
||||||
|
DataType.INT8,
|
||||||
|
DataType.INT16,
|
||||||
|
DataType.INT32,
|
||||||
|
DataType.INT64,
|
||||||
|
DataType.UINT8,
|
||||||
|
DataType.UINT16,
|
||||||
|
DataType.UINT32,
|
||||||
|
DataType.UINT64
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private DataType dataType;
|
||||||
|
|
||||||
|
public PythonNumpyTest(DataType dataType) {
|
||||||
|
this.dataType = dataType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void numpyAndNd4jConversions() throws Exception {
|
||||||
|
INDArray input = Nd4j.ones(dataType, 2, 2, 2);
|
||||||
|
|
||||||
|
PythonVariables inputs = new PythonVariables();
|
||||||
|
inputs.addNDArray("x", input);
|
||||||
|
|
||||||
|
PythonVariables outputs = new PythonVariables();
|
||||||
|
outputs.addNDArray("y");
|
||||||
|
|
||||||
|
PythonJob pythonJob = new PythonJob(String.format("job_%s", dataType.name()) + dataType.name(), "y = x", false);
|
||||||
|
|
||||||
|
pythonJob.exec(inputs, outputs);
|
||||||
|
|
||||||
|
INDArray output = outputs.getNDArrayValue("y");
|
||||||
|
|
||||||
|
// As numpy doesn't support BFLOAT16 we'll convert it to FLOAT
|
||||||
|
assertEquals(dataType == DataType.BFLOAT16 ? input.castTo(DataType.FLOAT) : input,
|
||||||
|
output);
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,10 +20,13 @@ import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.nativeblas.OpaqueDataBuffer;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
@ -223,6 +226,35 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNDArrayNoCopy() throws Exception{
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
INDArray arr = Nd4j.rand(3, 2);
|
||||||
|
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||||
|
pyInputs.addNDArray("x", arr);
|
||||||
|
pyOutputs.addNDArray("x");
|
||||||
|
INDArray expected = arr.mul(2.3);
|
||||||
|
String code = "x *= 2.3";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals(pyInputs.getNDArrayValue("x"), pyOutputs.getNDArrayValue("x"));
|
||||||
|
Assert.assertEquals(expected, pyOutputs.getNDArrayValue("x"));
|
||||||
|
Assert.assertEquals(arr.data().address(), pyOutputs.getNDArrayValue("x").data().address());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNDArrayInplace() throws Exception{
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
INDArray arr = Nd4j.rand(3, 2);
|
||||||
|
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||||
|
pyInputs.addNDArray("x", arr);
|
||||||
|
INDArray expected = arr.mul(2.3);
|
||||||
|
String code = "x *= 2.3";
|
||||||
|
Python.exec(code, pyInputs, null);
|
||||||
|
Assert.assertEquals(expected, arr);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testByteBufferInput() throws Exception{
|
public void testByteBufferInput() throws Exception{
|
||||||
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
|
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
|
||||||
|
@ -230,8 +262,7 @@ public class TestPythonExecutioner {
|
||||||
buff.putScalar(0, 97); // a
|
buff.putScalar(0, 97); // a
|
||||||
buff.putScalar(1, 98); // b
|
buff.putScalar(1, 98); // b
|
||||||
buff.putScalar(2, 99); // c
|
buff.putScalar(2, 99); // c
|
||||||
|
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||||
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
@ -251,6 +282,7 @@ public class TestPythonExecutioner {
|
||||||
buff.putScalar(0, 97); // a
|
buff.putScalar(0, 97); // a
|
||||||
buff.putScalar(1, 98); // b
|
buff.putScalar(1, 98); // b
|
||||||
buff.putScalar(2, 99); // c
|
buff.putScalar(2, 99); // c
|
||||||
|
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||||
|
|
||||||
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
@ -262,6 +294,7 @@ public class TestPythonExecutioner {
|
||||||
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
|
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
|
||||||
Python.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||||
|
Assert.assertEquals(buff.data().address(), pyOutputs.getBytesValue("buff").address());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -270,6 +303,8 @@ public class TestPythonExecutioner {
|
||||||
buff.putScalar(0, 97); // a
|
buff.putScalar(0, 97); // a
|
||||||
buff.putScalar(1, 98); // b
|
buff.putScalar(1, 98); // b
|
||||||
buff.putScalar(2, 99); // c
|
buff.putScalar(2, 99); // c
|
||||||
|
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||||
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
String code = "buff[0]+=2\nbuff[2]-=2";
|
String code = "buff[0]+=2\nbuff[2]-=2";
|
||||||
|
@ -288,6 +323,7 @@ public class TestPythonExecutioner {
|
||||||
buff.putScalar(0, 97); // a
|
buff.putScalar(0, 97); // a
|
||||||
buff.putScalar(1, 98); // b
|
buff.putScalar(1, 98); // b
|
||||||
buff.putScalar(2, 99); // c
|
buff.putScalar(2, 99); // c
|
||||||
|
((BaseDataBuffer)buff.data()).syncToPrimary();
|
||||||
|
|
||||||
|
|
||||||
PythonVariables pyInputs = new PythonVariables();
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
@ -300,6 +336,30 @@ public class TestPythonExecutioner {
|
||||||
Python.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDoubleDeviceAllocation() throws Exception{
|
||||||
|
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Test to make sure that multiple device buffers are not allocated
|
||||||
|
// for the same host buffer
|
||||||
|
INDArray arr = Nd4j.rand(3, 2);
|
||||||
|
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||||
|
long deviceAddress1 = getDeviceAddress(arr);
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addNDArray("arr", arr);
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
pyOutputs.addNDArray("arr");
|
||||||
|
String code = "arr += 2";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
INDArray arr2 = pyOutputs.getNDArrayValue("arr");
|
||||||
|
long deviceAddress2 = getDeviceAddress(arr2);
|
||||||
|
Assert.assertEquals(deviceAddress1, deviceAddress2);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBadCode() throws Exception{
|
public void testBadCode() throws Exception{
|
||||||
Python.setContext("badcode");
|
Python.setContext("badcode");
|
||||||
|
@ -333,5 +393,22 @@ public class TestPythonExecutioner {
|
||||||
Assert.assertEquals("y", notNone.toString());
|
Assert.assertEquals("y", notNone.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static long getDeviceAddress(INDArray array){
|
||||||
|
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
|
||||||
|
throw new IllegalStateException("Cannot ge device pointer for non-CUDA device");
|
||||||
|
}
|
||||||
|
|
||||||
|
//Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
|
||||||
|
// due to it being defined in nd4j-native-api, not nd4j-api
|
||||||
|
try {
|
||||||
|
Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
|
||||||
|
Method m = c.getMethod("getOpaqueDataBuffer");
|
||||||
|
OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
|
||||||
|
long address = db.specialBuffer().address();
|
||||||
|
return address;
|
||||||
|
} catch (Throwable t){
|
||||||
|
throw new RuntimeException("Error getting OpaqueDataBuffer", t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ package org.datavec.python;
|
||||||
|
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -55,6 +56,7 @@ public class TestPythonVariables {
|
||||||
};
|
};
|
||||||
|
|
||||||
INDArray arr = Nd4j.scalar(1.0);
|
INDArray arr = Nd4j.scalar(1.0);
|
||||||
|
((BaseDataBuffer)arr.data()).syncToPrimary();
|
||||||
BytePointer bp = new BytePointer(arr.data().pointer());
|
BytePointer bp = new BytePointer(arr.data().pointer());
|
||||||
Object[] values = {
|
Object[] values = {
|
||||||
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
1L,1.0,"1",true, Collections.singletonMap("1",1),
|
||||||
|
|
|
@ -1945,4 +1945,17 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
public boolean wasClosed() {
|
public boolean wasClosed() {
|
||||||
return released;
|
return released;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method synchronizes host memory
|
||||||
|
*/
|
||||||
|
public abstract void syncToPrimary();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method synchronizes device memory
|
||||||
|
*/
|
||||||
|
public abstract void syncToSpecial();
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -230,13 +230,14 @@ public enum DataType {
|
||||||
public static DataType fromNumpy(String numpyDtypeName){
|
public static DataType fromNumpy(String numpyDtypeName){
|
||||||
switch (numpyDtypeName.toLowerCase()){
|
switch (numpyDtypeName.toLowerCase()){
|
||||||
case "bool": return BOOL;
|
case "bool": return BOOL;
|
||||||
case "byte": return BYTE;
|
case "byte":
|
||||||
case "int8": return BYTE;
|
case "int8":
|
||||||
case "int16": return SHORT;
|
return INT8;
|
||||||
case "int32": return INT;
|
case "int16": return INT16;
|
||||||
case "int64": return LONG;
|
case "int32": return INT32;
|
||||||
case "uint8": return UBYTE;
|
case "int64": return INT64;
|
||||||
case "float16": return HALF;
|
case "uint8": return UINT8;
|
||||||
|
case "float16": return FLOAT16;
|
||||||
case "float32": return FLOAT;
|
case "float32": return FLOAT;
|
||||||
case "float64": return DOUBLE;
|
case "float64": return DOUBLE;
|
||||||
case "uint16": return UINT16;
|
case "uint16": return UINT16;
|
||||||
|
|
|
@ -211,6 +211,16 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
||||||
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToPrimary() {
|
||||||
|
//No-op
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToSpecial() {
|
||||||
|
//No-op
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected double getDoubleUnsynced(long index) {
|
protected double getDoubleUnsynced(long index) {
|
||||||
return super.getDouble(index);
|
return super.getDouble(index);
|
||||||
|
|
|
@ -51,7 +51,6 @@ public class OpaqueDataBuffer extends Pointer {
|
||||||
try {
|
try {
|
||||||
// try to allocate data buffer
|
// try to allocate data buffer
|
||||||
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
|
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
|
||||||
|
|
||||||
// check error code
|
// check error code
|
||||||
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
if (ec != 0) {
|
if (ec != 0) {
|
||||||
|
|
|
@ -1845,4 +1845,14 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
public int targetDevice() {
|
public int targetDevice() {
|
||||||
return AtomicAllocator.getInstance().getAllocationPoint(this).getDeviceId();
|
return AtomicAllocator.getInstance().getAllocationPoint(this).getDeviceId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToPrimary(){
|
||||||
|
ptrDataBuffer.syncToPrimary();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToSpecial(){
|
||||||
|
ptrDataBuffer.syncToSpecial();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -951,4 +951,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToPrimary(){
|
||||||
|
ptrDataBuffer.syncToPrimary();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncToSpecial(){
|
||||||
|
ptrDataBuffer.syncToSpecial();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,8 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import sun.nio.ch.DirectBuffer;
|
||||||
|
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
@ -398,6 +400,29 @@ public class DataBufferTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEnsureLocation(){
|
||||||
|
//https://github.com/eclipse/deeplearning4j/issues/8783
|
||||||
|
Nd4j.create(1);
|
||||||
|
|
||||||
|
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
|
||||||
|
System.out.println(bb.getClass());
|
||||||
|
System.out.println(bb.address());
|
||||||
|
|
||||||
|
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
|
||||||
|
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
|
||||||
|
|
||||||
|
|
||||||
|
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
|
||||||
|
long before = arr2.data().pointer().address();
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
|
||||||
|
long after = arr2.data().pointer().address();
|
||||||
|
|
||||||
|
assertEquals(before, after);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue