Fixing python object for obtaining scalars (#330)
* Fixing python object for obtaining scalars Signed-off-by: shams <shamsazeem20@gmail.com> * Fix variable name for stridePtr Signed-off-by: shams <shamsazeem20@gmail.com> * Fix variable name for stridePtr Signed-off-by: shams <shamsazeem20@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
3cbba49518
commit
4e8f3a025f
|
@ -77,7 +77,7 @@ public class PythonObject {
|
||||||
|
|
||||||
long address = bp.address();
|
long address = bp.address();
|
||||||
long size = bp.capacity();
|
long size = bp.capacity();
|
||||||
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
|
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.INT8).build();
|
||||||
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
|
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,20 +320,23 @@ public class PythonObject {
|
||||||
public NumpyArray toNumpy() throws PythonException{
|
public NumpyArray toNumpy() throws PythonException{
|
||||||
PyObject np = PyImport_ImportModule("numpy");
|
PyObject np = PyImport_ImportModule("numpy");
|
||||||
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
|
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
|
||||||
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){
|
if (PyObject_IsInstance(nativePythonObject, ndarray) != 1){
|
||||||
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
|
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
|
||||||
}
|
}
|
||||||
Py_DecRef(ndarray);
|
Py_DecRef(ndarray);
|
||||||
Py_DecRef(np);
|
Py_DecRef(np);
|
||||||
|
|
||||||
Pointer objPtr = new Pointer(nativePythonObject);
|
Pointer objPtr = new Pointer(nativePythonObject);
|
||||||
PyArrayObject npArr = new PyArrayObject(objPtr);
|
PyArrayObject npArr = new PyArrayObject(objPtr);
|
||||||
Pointer ptr = PyArray_DATA(npArr);
|
Pointer ptr = PyArray_DATA(npArr);
|
||||||
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
|
||||||
long[] shape = new long[PyArray_NDIM(npArr)];
|
long[] shape = new long[PyArray_NDIM(npArr)];
|
||||||
shapePtr.get(shape, 0, shape.length);
|
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
||||||
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
if (shapePtr != null)
|
||||||
|
shapePtr.get(shape, 0, shape.length);
|
||||||
long[] strides = new long[shape.length];
|
long[] strides = new long[shape.length];
|
||||||
stridesPtr.get(strides, 0, strides.length);
|
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
||||||
|
if (stridesPtr != null)
|
||||||
|
stridesPtr.get(strides, 0, strides.length);
|
||||||
int npdtype = PyArray_TYPE(npArr);
|
int npdtype = PyArray_TYPE(npArr);
|
||||||
|
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
|
@ -345,28 +348,27 @@ public class PythonObject {
|
||||||
case NPY_SHORT:
|
case NPY_SHORT:
|
||||||
dtype = DataType.SHORT; break;
|
dtype = DataType.SHORT; break;
|
||||||
case NPY_INT:
|
case NPY_INT:
|
||||||
dtype = DataType.INT; break;
|
dtype = DataType.INT32; break;
|
||||||
case NPY_LONG:
|
case NPY_LONG:
|
||||||
dtype = DataType.LONG; break;
|
dtype = DataType.LONG; break;
|
||||||
case NPY_UINT:
|
case NPY_UINT:
|
||||||
dtype = DataType.UINT32; break;
|
dtype = DataType.UINT32; break;
|
||||||
case NPY_BYTE:
|
case NPY_BYTE:
|
||||||
dtype = DataType.BYTE; break;
|
dtype = DataType.INT8; break;
|
||||||
case NPY_UBYTE:
|
case NPY_UBYTE:
|
||||||
dtype = DataType.UBYTE; break;
|
dtype = DataType.UINT8; break;
|
||||||
case NPY_BOOL:
|
case NPY_BOOL:
|
||||||
dtype = DataType.BOOL; break;
|
dtype = DataType.BOOL; break;
|
||||||
case NPY_HALF:
|
case NPY_HALF:
|
||||||
dtype = DataType.HALF; break;
|
dtype = DataType.FLOAT16; break;
|
||||||
case NPY_LONGLONG:
|
case NPY_LONGLONG:
|
||||||
dtype = DataType.INT64; break;
|
dtype = DataType.INT64; break;
|
||||||
case NPY_USHORT:
|
case NPY_USHORT:
|
||||||
dtype = DataType.UINT16; break;
|
dtype = DataType.UINT16; break;
|
||||||
case NPY_ULONG:
|
case NPY_ULONG:
|
||||||
dtype = DataType.UINT64; break;
|
|
||||||
case NPY_ULONGLONG:
|
case NPY_ULONGLONG:
|
||||||
dtype = DataType.UINT64; break;
|
dtype = DataType.UINT64; break;
|
||||||
default:
|
default:
|
||||||
throw new PythonException("Unsupported array data type: " + npdtype);
|
throw new PythonException("Unsupported array data type: " + npdtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static junit.framework.TestCase.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
public class ScalarAndArrayTest {
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "{index}: Testing with INDArray={0}")
|
||||||
|
public static INDArray[] data() {
|
||||||
|
return new INDArray[]{
|
||||||
|
Nd4j.scalar(10),
|
||||||
|
Nd4j.ones(10, 10, 10, 10)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private INDArray indArray;
|
||||||
|
|
||||||
|
public ScalarAndArrayTest(INDArray indArray) {
|
||||||
|
this.indArray = indArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testINDArray() throws PythonException {
|
||||||
|
assertEquals(indArray, new PythonObject(indArray).toNumpy().getNd4jArray());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue