Fixing python object for obtaining scalars (#330)
* Fixing python object for obtaining scalars Signed-off-by: shams <shamsazeem20@gmail.com> * Fix variable name for stridePtr Signed-off-by: shams <shamsazeem20@gmail.com> * Fix variable name for stridePtr Signed-off-by: shams <shamsazeem20@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
3cbba49518
commit
4e8f3a025f
|
@ -77,7 +77,7 @@ public class PythonObject {
|
|||
|
||||
long address = bp.address();
|
||||
long size = bp.capacity();
|
||||
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build();
|
||||
NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.INT8).build();
|
||||
nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject;
|
||||
}
|
||||
|
||||
|
@ -320,20 +320,23 @@ public class PythonObject {
|
|||
public NumpyArray toNumpy() throws PythonException{
|
||||
PyObject np = PyImport_ImportModule("numpy");
|
||||
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
|
||||
if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){
|
||||
if (PyObject_IsInstance(nativePythonObject, ndarray) != 1){
|
||||
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
|
||||
}
|
||||
Py_DecRef(ndarray);
|
||||
Py_DecRef(np);
|
||||
|
||||
Pointer objPtr = new Pointer(nativePythonObject);
|
||||
PyArrayObject npArr = new PyArrayObject(objPtr);
|
||||
Pointer ptr = PyArray_DATA(npArr);
|
||||
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
||||
long[] shape = new long[PyArray_NDIM(npArr)];
|
||||
shapePtr.get(shape, 0, shape.length);
|
||||
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
||||
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
|
||||
if (shapePtr != null)
|
||||
shapePtr.get(shape, 0, shape.length);
|
||||
long[] strides = new long[shape.length];
|
||||
stridesPtr.get(strides, 0, strides.length);
|
||||
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
|
||||
if (stridesPtr != null)
|
||||
stridesPtr.get(strides, 0, strides.length);
|
||||
int npdtype = PyArray_TYPE(npArr);
|
||||
|
||||
DataType dtype;
|
||||
|
@ -345,28 +348,27 @@ public class PythonObject {
|
|||
case NPY_SHORT:
|
||||
dtype = DataType.SHORT; break;
|
||||
case NPY_INT:
|
||||
dtype = DataType.INT; break;
|
||||
dtype = DataType.INT32; break;
|
||||
case NPY_LONG:
|
||||
dtype = DataType.LONG; break;
|
||||
case NPY_UINT:
|
||||
dtype = DataType.UINT32; break;
|
||||
case NPY_BYTE:
|
||||
dtype = DataType.BYTE; break;
|
||||
dtype = DataType.INT8; break;
|
||||
case NPY_UBYTE:
|
||||
dtype = DataType.UBYTE; break;
|
||||
dtype = DataType.UINT8; break;
|
||||
case NPY_BOOL:
|
||||
dtype = DataType.BOOL; break;
|
||||
case NPY_HALF:
|
||||
dtype = DataType.HALF; break;
|
||||
dtype = DataType.FLOAT16; break;
|
||||
case NPY_LONGLONG:
|
||||
dtype = DataType.INT64; break;
|
||||
case NPY_USHORT:
|
||||
dtype = DataType.UINT16; break;
|
||||
case NPY_ULONG:
|
||||
dtype = DataType.UINT64; break;
|
||||
case NPY_ULONGLONG:
|
||||
dtype = DataType.UINT64; break;
|
||||
default:
|
||||
default:
|
||||
throw new PythonException("Unsupported array data type: " + npdtype);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.python;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static junit.framework.TestCase.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class ScalarAndArrayTest {
|
||||
|
||||
@Parameterized.Parameters(name = "{index}: Testing with INDArray={0}")
|
||||
public static INDArray[] data() {
|
||||
return new INDArray[]{
|
||||
Nd4j.scalar(10),
|
||||
Nd4j.ones(10, 10, 10, 10)
|
||||
};
|
||||
}
|
||||
|
||||
private INDArray indArray;
|
||||
|
||||
public ScalarAndArrayTest(INDArray indArray) {
|
||||
this.indArray = indArray;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testINDArray() throws PythonException {
|
||||
assertEquals(indArray, new PythonObject(indArray).toNumpy().getNd4jArray());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue