513 lines
14 KiB
Python
513 lines
14 KiB
Python
################################################################################
|
|
# Copyright (c) 2015-2018 Skymind, Inc.
|
|
#
|
|
# This program and the accompanying materials are made available under the
|
|
# terms of the Apache License, Version 2.0 which is available at
|
|
# https://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
################################################################################
|
|
|
|
|
|
from .java_classes import *
|
|
import numpy as np
|
|
import ctypes
|
|
import warnings
|
|
|
|
|
|
native_ops = NativeOpsHolder.getInstance().getDeviceNativeOps()
|
|
|
|
|
|
# DATA TYPE MANAGEMENT
|
|
|
|
|
|
DOUBLE = DataType.DOUBLE
|
|
FLOAT = DataType.FLOAT
|
|
HALF = DataType.HALF
|
|
LONG = DataType.LONG
|
|
INT = DataType.INT
|
|
SHORT = DataType.SHORT
|
|
UBYTE = DataType.UBYTE
|
|
BYTE = DataType.BYTE
|
|
BOOL = DataType.BOOL
|
|
UTF8 = DataType.UTF8
|
|
COMPRESSED = DataType.COMPRESSED
|
|
UNKNOWN = DataType.UNKNOWN
|
|
|
|
SUPPORTED_JAVA_DTYPES = [
|
|
DOUBLE,
|
|
FLOAT,
|
|
HALF,
|
|
|
|
LONG,
|
|
INT,
|
|
SHORT,
|
|
|
|
BOOL
|
|
#UTF8
|
|
]
|
|
|
|
SUPPORTED_PYTHON_DTYPES = [
|
|
np.float64,
|
|
np.float32,
|
|
np.float16,
|
|
|
|
np.int64,
|
|
np.int32,
|
|
np.int16,
|
|
|
|
np.bool_
|
|
#np.str_
|
|
]
|
|
|
|
|
|
|
|
|
|
_PY2J = {SUPPORTED_PYTHON_DTYPES[i] : SUPPORTED_JAVA_DTYPES[i] for i in range(len(SUPPORTED_JAVA_DTYPES))}
|
|
_J2PY = {SUPPORTED_JAVA_DTYPES[i] : SUPPORTED_PYTHON_DTYPES[i] for i in range(len(SUPPORTED_JAVA_DTYPES))}
|
|
|
|
|
|
def _dtype_py2j(dtype):
|
|
if isinstance(dtype, str):
|
|
dtype = np.dtype(dtype).type
|
|
elif isinstance(dtype, np.dtype):
|
|
dtype = dtype.type
|
|
jtype = _PY2J.get(dtype)
|
|
if jtype is None:
|
|
raise NotImplementedError("Unsupported type: " + dtype.name)
|
|
return jtype
|
|
|
|
|
|
def _dtype_j2py(dtype):
|
|
pytype = _J2PY.get(dtype)
|
|
if pytype is None:
|
|
raise NotImplementedError("Unsupported type: " + (str(dtype)))
|
|
return pytype
|
|
|
|
|
|
def set_context_dtype(dtype):
|
|
'''
|
|
Sets the dtype for nd4j
|
|
# Arguments
|
|
dtype: 'float' or 'double'
|
|
'''
|
|
dtype_map = {
|
|
'float32': 'float',
|
|
'float64': 'double'
|
|
}
|
|
dtype = dtype_map.get(dtype, dtype)
|
|
if dtype not in ['float', 'double']:
|
|
raise ValueError("Invalid dtype '{}'. Available dtypes are 'float' and 'double'.".format(dtype))
|
|
dtype_ = DataTypeUtil.getDtypeFromContext(dtype)
|
|
DataTypeUtil.setDTypeForContext(dtype_)
|
|
if get_context_dtype() != dtype:
|
|
warnings.warn("Can not set context dtype now. Set it at the beginning of your program.")
|
|
|
|
|
|
def get_context_dtype():
|
|
'''
|
|
Returns the nd4j dtype
|
|
'''
|
|
dtype = DataTypeUtil.getDtypeFromContext()
|
|
return DataTypeUtil.getDTypeForName(dtype)
|
|
|
|
_refs = []
|
|
|
|
|
|
def _from_numpy(np_array):
|
|
'''
|
|
Convert numpy array to nd4j array
|
|
'''
|
|
pointer_address, _ = np_array.__array_interface__['data']
|
|
_refs.append(np_array)
|
|
pointer = native_ops.pointerForAddress(pointer_address)
|
|
size = np_array.size
|
|
pointer.limit(size)
|
|
jdtype = _dtype_py2j(np_array.dtype)
|
|
'''
|
|
mapping = {
|
|
DOUBLE: DoublePointer,
|
|
FLOAT: FloatPointer,
|
|
HALF: HalfPointer,
|
|
LONG: LongPointer,
|
|
INT: IntPointer,
|
|
SHORT: ShortPointer,
|
|
BOOL: BoolPointer
|
|
}
|
|
pc = mapping[jdtype]
|
|
#pointer = pc(pointer)
|
|
'''
|
|
buff = Nd4j.createBuffer(pointer, size, jdtype)
|
|
assert buff.address() == pointer_address
|
|
_refs.append(buff)
|
|
elem_size = buff.getElementSize()
|
|
assert elem_size == np_array.dtype.itemsize
|
|
strides = np_array.strides
|
|
strides = [dim / elem_size for dim in strides]
|
|
shape = np_array.shape
|
|
nd4j_array = Nd4j.create(buff, shape, strides, 0)
|
|
assert buff.address() == nd4j_array.data().address()
|
|
return nd4j_array
|
|
|
|
|
|
def _to_numpy(nd4j_array):
|
|
'''
|
|
Convert nd4j array to numpy array
|
|
'''
|
|
buff = nd4j_array.data()
|
|
address = buff.pointer().address()
|
|
dtype = nd4j_array.dataType().toString()
|
|
mapping = {
|
|
'DOUBLE': ctypes.c_double,
|
|
'FLOAT': ctypes.c_float,
|
|
'HALF': ctypes.c_short,
|
|
'LONG': ctypes.c_long,
|
|
'INT': ctypes.c_int,
|
|
'SHORT': ctypes.c_short,
|
|
'BOOL': ctypes.c_bool
|
|
}
|
|
Pointer = ctypes.POINTER(mapping[dtype])
|
|
pointer = ctypes.cast(address, Pointer)
|
|
np_array = np.ctypeslib.as_array(pointer, tuple(nd4j_array.shape()))
|
|
return np_array
|
|
|
|
|
|
def _indarray(x):
|
|
typ = type(x)
|
|
if typ is INDArray:
|
|
return x
|
|
elif typ is ndarray:
|
|
return x.array
|
|
elif 'numpy' in str(typ):
|
|
return _from_numpy(x)
|
|
elif typ in (list, tuple):
|
|
return _from_numpy(np.array(x))
|
|
elif typ in (int, float):
|
|
return Nd4j.scalar(x)
|
|
else:
|
|
raise Exception('Data type not understood :' + str(typ))
|
|
|
|
|
|
def _nparray(x):
|
|
typ = type(x)
|
|
if typ is INDArray:
|
|
return ndarray(x).numpy()
|
|
elif typ is ndarray:
|
|
return x.numpy()
|
|
elif 'numpy' in str(typ):
|
|
return x
|
|
elif typ in (list, tuple):
|
|
return np.array(x)
|
|
elif typ in (int, float):
|
|
return np.array(x)
|
|
else:
|
|
raise Exception('Data type not understood :' + str(typ))
|
|
|
|
|
|
def broadcast_like(y, x):
|
|
xs = x.shape()
|
|
ys = y.shape()
|
|
if xs == ys:
|
|
return y
|
|
_xs = tuple(xs)
|
|
_ys = tuple(ys)
|
|
nx = len(xs)
|
|
ny = len(ys)
|
|
if nx > ny:
|
|
diff = nx - ny
|
|
ys = ([1] * diff) + ys
|
|
y = y.reshape(ys)
|
|
ny = nx
|
|
elif ny > nx:
|
|
raise Exception('Unable to broadcast shapes ' + str(_xs) + ''
|
|
' and ' + str(_ys))
|
|
yt = []
|
|
rep_y = False
|
|
for xd, yd in zip(xs, ys):
|
|
if xd == yd:
|
|
yt.append(1)
|
|
elif xd == 1:
|
|
raise Exception('Unable to broadcast shapes ' + str(_xs) + ''
|
|
' and ' + str(_ys))
|
|
elif yd == 1:
|
|
yt.append(xd)
|
|
rep_y = True
|
|
else:
|
|
raise Exception('Unable to broadcast shapes ' + str(_xs) + ''
|
|
' and ' + str(_ys))
|
|
if rep_y:
|
|
y = y.repmat(*yt)
|
|
return y
|
|
|
|
|
|
def broadcast(x, y):
|
|
xs = x.shape()
|
|
ys = y.shape()
|
|
if xs == ys:
|
|
return x, y
|
|
_xs = tuple(xs)
|
|
_ys = tuple(ys)
|
|
nx = len(xs)
|
|
ny = len(ys)
|
|
if nx > ny:
|
|
diff = nx - ny
|
|
ys = ([1] * diff) + ys
|
|
y = y.reshape(*ys)
|
|
ny = nx
|
|
elif ny > nx:
|
|
diff = ny - nx
|
|
xs = ([1] * diff) + xs
|
|
x = x.reshape(*xs)
|
|
nx = ny
|
|
xt = []
|
|
yt = []
|
|
rep_x = False
|
|
rep_y = False
|
|
for xd, yd in zip(xs, ys):
|
|
if xd == yd:
|
|
xt.append(1)
|
|
yt.append(1)
|
|
elif xd == 1:
|
|
xt.append(yd)
|
|
yt.append(1)
|
|
rep_x = True
|
|
elif yd == 1:
|
|
xt.append(1)
|
|
yt.append(xd)
|
|
rep_y = True
|
|
else:
|
|
raise Exception('Unable to broadcast shapes ' + str(_xs) + ''
|
|
' and ' + str(_ys))
|
|
if rep_x:
|
|
x = Nd4j.tile(x, *xt)
|
|
if rep_y:
|
|
try:
|
|
y = Nd4j.tile(y, *yt)
|
|
except:
|
|
y = Nd4j.tile(y, *yt)
|
|
return x, y
|
|
|
|
|
|
class ndarray(object):
|
|
|
|
def __init__(self, data, dtype=None):
|
|
# we ignore dtype for now
|
|
typ = type(data)
|
|
if 'nd4j' in typ.__name__:
|
|
# Note that we don't make a copy here
|
|
self.array = data
|
|
elif typ is ndarray:
|
|
self.array = data.array.dup()
|
|
else:
|
|
if typ is not np.ndarray:
|
|
data = np.array(data)
|
|
self.array = _from_numpy(data)
|
|
|
|
def numpy(self):
|
|
try:
|
|
return self.np_array
|
|
except AttributeError:
|
|
self.np_array = _to_numpy(self.array)
|
|
return self.np_array
|
|
|
|
@property
|
|
def size(self):
|
|
return self.array.length()
|
|
|
|
@property
|
|
def shape(self):
|
|
return tuple(self.array.shape())
|
|
|
|
@shape.setter
|
|
def shape(self, value):
|
|
arr = self.reshape(value)
|
|
self.array = arr.array
|
|
|
|
@property
|
|
def ndim(self):
|
|
return len(self.array.shape())
|
|
|
|
def __getitem__(self, key):
|
|
return ndarray(self.numpy()[key])
|
|
if type(key) is int:
|
|
return ndarray(self.array.get(NDArrayIndex.point(key)))
|
|
if type(key) is slice:
|
|
start = key.start
|
|
stop = key.stop
|
|
step = key.step
|
|
if start is None:
|
|
start = 0
|
|
if stop is None:
|
|
shape = self.array.shape()
|
|
if shape[0] == 1:
|
|
stop = shape[1]
|
|
else:
|
|
stop = shape[0]
|
|
if stop - start <= 0:
|
|
return None
|
|
if step is None or step == 1:
|
|
return ndarray(self.array.get(NDArrayIndex.interval(start, stop)))
|
|
else:
|
|
return ndarray(self.array.get(NDArrayIndex.interval(start, step, stop)))
|
|
if type(key) is list:
|
|
raise NotImplementedError(
|
|
'Sorry, this type of indexing is not supported yet.')
|
|
if type(key) is tuple:
|
|
key = list(key)
|
|
shape = self.array.shape()
|
|
ndim = len(shape)
|
|
nk = len(key)
|
|
key += [slice(None)] * (ndim - nk)
|
|
args = []
|
|
for i, dim in enumerate(key):
|
|
if type(dim) is int:
|
|
args.append(NDArrayIndex.point(dim))
|
|
elif type(dim) is slice:
|
|
if dim == slice(None):
|
|
args.append(NDArrayIndex.all())
|
|
else:
|
|
start = dim.start
|
|
stop = dim.stop
|
|
step = dim.step
|
|
if start is None:
|
|
start = 0
|
|
if stop is None:
|
|
stop = shape[i]
|
|
if stop - start <= 0:
|
|
return None
|
|
if step is None or step == 1:
|
|
args.append(NDArrayIndex.interval(start, stop))
|
|
else:
|
|
args.append(NDArrayIndex.interval(
|
|
start, step, stop))
|
|
elif type(dim) in (list, tuple):
|
|
raise NotImplementedError(
|
|
'Sorry, this type of indexing is not supported yet.')
|
|
return ndarray(self.array.get(*args))
|
|
|
|
def __setitem__(self, key, other):
|
|
self.numpy()[key] = _nparray(other)
|
|
return
|
|
other = _indarray(other)
|
|
view = self[key]
|
|
if view is None:
|
|
return
|
|
view = view.array
|
|
other = broadcast_like(other, view)
|
|
view.assign(other)
|
|
|
|
def __add__(self, other):
|
|
return ndarray(self.numpy() + _nparray(other))
|
|
other = _indarray(other)
|
|
x, y = broadcast(self.array, other)
|
|
return ndarray(x.add(y))
|
|
|
|
def __sub__(self, other):
|
|
return ndarray(self.numpy() - _nparray(other))
|
|
other = _indarray(other)
|
|
x, y = broadcast(self.array, other)
|
|
return ndarray(x.sub(y))
|
|
|
|
def __mul__(self, other):
|
|
return ndarray(self.numpy() * _nparray(other))
|
|
other = _indarray(other)
|
|
x, y = broadcast(self.array, other)
|
|
return ndarray(x.mul(y))
|
|
|
|
def __div__(self, other):
|
|
return ndarray(self.numpy() / _nparray(other))
|
|
other = _indarray(other)
|
|
x, y = broadcast(self.array, other)
|
|
return ndarray(x.div(y))
|
|
|
|
def __pow__(self, other):
|
|
return ndarray(self.numpy() ** _nparray(other))
|
|
other = _indarray(other)
|
|
x, y = broadcast(self.array, other)
|
|
return ndarray(Transforms.pow(x, y))
|
|
|
|
def __iadd__(self, other):
|
|
self.numpy().__iadd__(_nparray(other))
|
|
return self
|
|
other = _indarray(other)
|
|
if self.array.shape() == other.shape():
|
|
self.array = self.array.addi(other)
|
|
else:
|
|
x, y = broadcast(self.array, other)
|
|
self.array = x.add(y)
|
|
return self
|
|
|
|
def __isub__(self, other):
|
|
self.numpy().__isub__(_nparray(other))
|
|
return self
|
|
other = _indarray(other)
|
|
if self.array.shape() == other.shape():
|
|
self.array = self.array.subi(other)
|
|
else:
|
|
x, y = broadcast(self.array, other)
|
|
self.array = x.sub(y)
|
|
return self
|
|
|
|
def __imul__(self, other):
|
|
self.numpy().__imul__(_nparray(other))
|
|
return self
|
|
other = _indarray(other)
|
|
if self.array.shape() == other.shape():
|
|
self.array = self.array.muli(other)
|
|
else:
|
|
x, y = broadcast(self.array, other)
|
|
self.array = x.mul(y)
|
|
return self
|
|
|
|
def __idiv__(self, other):
|
|
self.numpy().__idiv__(_nparray(other))
|
|
return self
|
|
other = _indarray(other)
|
|
if self.array.shape() == other.shape():
|
|
self.array = self.array.divi(other)
|
|
else:
|
|
x, y = broadcast(self.array, other)
|
|
self.array = x.div(y)
|
|
return self
|
|
|
|
def __ipow__(self, other):
|
|
self.numpy().__ipow__(_nparray(other))
|
|
return self
|
|
other = _indarray(other)
|
|
if self.array.shape() == other.shape():
|
|
self.array = self.array.divi(other)
|
|
else:
|
|
x, y = broadcast(self.array, other)
|
|
self.array = Transforms.pow(x, y)
|
|
return self
|
|
|
|
def __getattr__(self, attr):
|
|
import ops
|
|
f = getattr(ops, attr)
|
|
setattr(ndarray, attr, f)
|
|
return getattr(self, attr)
|
|
|
|
def __int__(self):
|
|
if self.array.length() == 1:
|
|
return self.array.getInt(0)
|
|
raise Exception('Applicable only for scalars')
|
|
|
|
def __float__(self):
|
|
if self.array.length() == 1:
|
|
return self.array.getDouble(0)
|
|
raise Exception('Applicable only for scalars')
|
|
|
|
@property
|
|
def T(self):
|
|
return self.transpose()
|
|
|
|
|
|
def array(*args, **kwargs):
|
|
return ndarray(*args, **kwargs)
|