################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################


from .op import op
from ..java_classes import Nd4j
from ..ndarray import _nparray, ndarray, _indarray

# Array manipulation routines
# https://docs.scipy.org/doc/numpy-1.13.0/reference/routines.array-manipulation.html
#
# NOTE(review): inside the @op-decorated functions below, `arr` is presumably
# the underlying ND4J array (it exposes shape(), reshape(), permute(), ...);
# confirm against the `op` decorator.


def _as_axis_list(axes):
    """Return *axes* as a list, wrapping a single axis in a one-element list."""
    if isinstance(axes, (list, tuple)):
        return list(axes)
    return [axes]


def _permute(arr, axis):
    """Permute the dimensions of *arr*.

    *axis* is the tuple collected from a ``*axis`` parameter: either several
    ints, or a single list/tuple holding the whole permutation.
    """
    if len(axis) == 1:
        # A single sequence argument carries the full permutation.
        axis = axis[0]
    assert set(axis) in [set(range(len(axis))), set(range(len(arr.shape())))]
    return arr.permute(*axis)


@op
def reshape(arr, *args):
    """Return *arr* reshaped to the given shape.

    The shape may be passed as separate ints (``reshape(a, 2, 3)``), as a
    single sequence (``reshape(a, (2, 3))``) or as a single int for a 1-D
    result (``reshape(a, 6)``).
    """
    if len(args) == 1:
        try:
            # A single iterable argument holds the whole target shape.
            args = tuple(args[0])
        except TypeError:
            # A single scalar is a 1-D target shape; pass it through as-is.
            pass
    return arr.reshape(*args)


@op
def transpose(arr, *axis):
    """Transpose *arr*.

    With no axes, delegates to ``arr.transpose()``. Otherwise the axes
    (separate ints or one sequence) give the new order of the dimensions.
    """
    if not axis:
        return arr.transpose()
    return _permute(arr, axis)


@op
def ravel(arr):
    """Return *arr* flattened via ``arr.ravel()``."""
    return arr.ravel()


@op
def flatten(arr):
    """Return a flattened copy of *arr* (``ravel`` followed by ``dup``)."""
    return arr.ravel().dup()


@op
def moveaxis(arr, source, destination):
    """Move axes of *arr* to new positions, keeping the others in order.

    *source* and *destination* are each an int or a sequence of ints with
    the same number of entries. Negative values count from the last axis.
    Follows the algorithm used by ``numpy.moveaxis``.
    """
    ndim = len(arr.shape())
    source = [s + ndim if s < 0 else s for s in _as_axis_list(source)]
    destination = [d + ndim if d < 0 else d
                   for d in _as_axis_list(destination)]
    assert len(source) == len(destination), \
        'source and destination should have the same number of axes.'
    # Start from the axes that stay put, then insert each moved axis at its
    # destination in ascending destination order, so that earlier inserts
    # never shift the target slot of a later one.
    order = [n for n in range(ndim) if n not in source]
    for dst, src in sorted(zip(destination, source)):
        order.insert(dst, src)
    return arr.permute(*order)


@op
def permute(arr, *axis):
    """Permute the dimensions of *arr* (separate ints or one sequence)."""
    return _permute(arr, axis)


@op
def expand_dims(arr, axis):
    """Insert a new size-1 dimension at *axis* via ``Nd4j.expandDims``."""
    return Nd4j.expandDims(arr, axis)


@op
def squeeze(arr, axis=None):
    """Remove size-1 dimensions from *arr*.

    *axis* may be an int, a sequence of ints (negative values count from
    the end), or ``None`` to drop every size-1 dimension, as in NumPy.
    """
    # Copy into a Python list: the shape may come back from the JVM.
    shape = list(arr.shape())
    ndim = len(shape)
    if axis is None:
        drop = {i for i, size in enumerate(shape) if size == 1}
    else:
        drop = {a + ndim if a < 0 else a for a in _as_axis_list(axis)}
        assert all(shape[a] == 1 for a in drop), \
            'cannot select an axis to squeeze out which has size not equal to one.'
    new_shape = [size for i, size in enumerate(shape) if i not in drop]
    return arr.reshape(*new_shape)


@op
def concatenate(arrs, axis=-1):
    """Join the arrays in *arrs* along an existing *axis* (``Nd4j.concat``)."""
    return Nd4j.concat(axis, *arrs)


@op
def hstack(arrs):
    """Stack *arrs* horizontally via ``Nd4j.hstack``."""
    return Nd4j.hstack(arrs)


@op
def vstack(arrs):
    """Stack *arrs* vertically via ``Nd4j.vstack``."""
    return Nd4j.vstack(arrs)


@op
def stack(arrs, axis=0):
    """Join same-shaped arrays along a new *axis*.

    Each input gets a size-1 dimension inserted at *axis* and the results
    are concatenated along it. A negative *axis* counts from the end of the
    result's dimensions, as in ``numpy.stack``. The caller's *arrs* is not
    modified.
    """
    arrs = list(arrs)
    if not arrs:
        raise ValueError('need at least one array to stack')
    # The result has one more dimension than each input.
    ndim = len(arrs[0].shape())
    if axis < 0:
        axis += ndim + 1
    expanded = []
    for item in arrs:
        shape = list(item.shape())
        shape.insert(axis, 1)
        expanded.append(item.reshape(*shape))
    return Nd4j.concat(axis, *expanded)


@op
def tile(arr, reps):
    """Repeat *arr* according to *reps*, delegating to ``numpy.tile``.

    The array is round-tripped through NumPy (``_nparray`` / ``_indarray``).
    """
    import numpy as np
    return _indarray(np.tile(_nparray(arr), reps))


@op
def repeat(arr, repeats, axis=None):
    """Repeat elements of *arr*.

    With ``axis=None`` the elements are repeated along the last dimension
    and the result is flattened; otherwise they are repeated along *axis*.
    """
    if type(repeats) is int:
        repeats = (repeats,)
    if axis is None:
        return arr.repeat(-1, *repeats).reshape(-1)
    else:
        return arr.repeat(axis, *repeats)