Fixing python object for obtaining scalars (#330)

* Fixing python object for obtaining scalars

Signed-off-by: shams <shamsazeem20@gmail.com>

* Fix variable name for stridePtr

Signed-off-by: shams <shamsazeem20@gmail.com>

* Fix variable name for stridePtr

Signed-off-by: shams <shamsazeem20@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Shams Ul Azeem 2020-03-24 13:11:57 +05:00 committed by GitHub
parent 3cbba49518
commit 4e8f3a025f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 12 deletions

View File

@ -77,7 +77,7 @@ public class PythonObject {
long address = bp.address(); long address = bp.address();
long size = bp.capacity(); long size = bp.capacity();
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build(); NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.INT8).build();
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject; nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
} }
@ -320,19 +320,22 @@ public class PythonObject {
public NumpyArray toNumpy() throws PythonException{ public NumpyArray toNumpy() throws PythonException{
PyObject np = PyImport_ImportModule("numpy"); PyObject np = PyImport_ImportModule("numpy");
PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){ if (PyObject_IsInstance(nativePythonObject, ndarray) != 1){
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array."); throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
} }
Py_DecRef(ndarray); Py_DecRef(ndarray);
Py_DecRef(np); Py_DecRef(np);
Pointer objPtr = new Pointer(nativePythonObject); Pointer objPtr = new Pointer(nativePythonObject);
PyArrayObject npArr = new PyArrayObject(objPtr); PyArrayObject npArr = new PyArrayObject(objPtr);
Pointer ptr = PyArray_DATA(npArr); Pointer ptr = PyArray_DATA(npArr);
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
long[] shape = new long[PyArray_NDIM(npArr)]; long[] shape = new long[PyArray_NDIM(npArr)];
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
if (shapePtr != null)
shapePtr.get(shape, 0, shape.length); shapePtr.get(shape, 0, shape.length);
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
long[] strides = new long[shape.length]; long[] strides = new long[shape.length];
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
if (stridesPtr != null)
stridesPtr.get(strides, 0, strides.length); stridesPtr.get(strides, 0, strides.length);
int npdtype = PyArray_TYPE(npArr); int npdtype = PyArray_TYPE(npArr);
@ -345,25 +348,24 @@ public class PythonObject {
case NPY_SHORT: case NPY_SHORT:
dtype = DataType.SHORT; break; dtype = DataType.SHORT; break;
case NPY_INT: case NPY_INT:
dtype = DataType.INT; break; dtype = DataType.INT32; break;
case NPY_LONG: case NPY_LONG:
dtype = DataType.LONG; break; dtype = DataType.LONG; break;
case NPY_UINT: case NPY_UINT:
dtype = DataType.UINT32; break; dtype = DataType.UINT32; break;
case NPY_BYTE: case NPY_BYTE:
dtype = DataType.BYTE; break; dtype = DataType.INT8; break;
case NPY_UBYTE: case NPY_UBYTE:
dtype = DataType.UBYTE; break; dtype = DataType.UINT8; break;
case NPY_BOOL: case NPY_BOOL:
dtype = DataType.BOOL; break; dtype = DataType.BOOL; break;
case NPY_HALF: case NPY_HALF:
dtype = DataType.HALF; break; dtype = DataType.FLOAT16; break;
case NPY_LONGLONG: case NPY_LONGLONG:
dtype = DataType.INT64; break; dtype = DataType.INT64; break;
case NPY_USHORT: case NPY_USHORT:
dtype = DataType.UINT16; break; dtype = DataType.UINT16; break;
case NPY_ULONG: case NPY_ULONG:
dtype = DataType.UINT64; break;
case NPY_ULONGLONG: case NPY_ULONGLONG:
dtype = DataType.UINT64; break; dtype = DataType.UINT64; break;
default: default:

View File

@ -0,0 +1,48 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static junit.framework.TestCase.assertEquals;
@RunWith(Parameterized.class)
public class ScalarAndArrayTest {
@Parameterized.Parameters(name = "{index}: Testing with INDArray={0}")
public static INDArray[] data() {
return new INDArray[]{
Nd4j.scalar(10),
Nd4j.ones(10, 10, 10, 10)
};
}
private INDArray indArray;
public ScalarAndArrayTest(INDArray indArray) {
this.indArray = indArray;
}
@Test
public void testINDArray() throws PythonException {
assertEquals(indArray, new PythonObject(indArray).toNumpy().getNd4jArray());
}
}