151 lines
3.6 KiB
151 lines
3.6 KiB
# Copyright (c) 2015-2018 Skymind, Inc.
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# SPDX-License-Identifier: Apache-2.0
from .op import op
from ..java_classes import Nd4j
from ..ndarray import _nparray, ndarray, _indarray
# Array manipulation routines
# https://docs.scipy.org/doc/numpy-1.13.0/reference/routines.array-manipulation.html
def reshape(arr, *args):
    """Return ``arr`` reshaped to a new shape.

    The shape may be given either as separate integers
    (``reshape(arr, 2, 3)``) or as a single list/tuple
    (``reshape(arr, (2, 3))``).

    Args:
        arr: array object exposing ``reshape(*dims)``.
        *args: the target dimensions, or one list/tuple holding them.

    Returns:
        Whatever ``arr.reshape`` returns for the unpacked dimensions.
    """
    # ``args`` itself is always a tuple, so it is the sole element that must
    # be inspected to tell whether the shape arrived as one sequence.
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        args = tuple(args[0])
    return arr.reshape(*args)
def transpose(arr, *axis):
    """Permute the dimensions of ``arr``.

    With no ``axis`` arguments this returns ``arr.transpose()``. Otherwise
    the new axis order may be passed as separate integers
    (``transpose(a, 2, 0, 1)``) or as a single list/tuple
    (``transpose(a, (2, 0, 1))``); negative values count from the last
    dimension.

    Args:
        arr: array object exposing ``shape()``, ``transpose()`` and
            ``permute(*axes)``.
        *axis: the new order of the dimensions.

    Returns:
        ``arr.transpose()`` when no axes are given, else
        ``arr.permute(*axes)`` with non-negative axes.

    Raises:
        ValueError: if the axes are not a permutation of the dimensions
            of ``arr``.
    """
    if not axis:
        return arr.transpose()
    if len(axis) == 1 and isinstance(axis[0], (list, tuple)):
        axis = axis[0]
    ndim = len(arr.shape())
    # Resolve negative axes before checking that every dimension appears
    # exactly once (explicit raise, since asserts vanish under -O).
    axes = [a + ndim if a < 0 else a for a in axis]
    if sorted(axes) != list(range(ndim)):
        raise ValueError('axes %r are not a permutation of the %d '
                         'dimensions of the array' % (tuple(axis), ndim))
    return arr.permute(*axes)
def ravel(arr):
    """Flatten ``arr`` by returning the result of ``arr.ravel()``.

    Unlike ``flatten`` in this module, no explicit ``dup()`` is applied.
    """
    flat = arr.ravel()
    return flat
def flatten(arr):
    """Flatten ``arr`` and return ``dup()`` of the flattened result.

    The extra ``dup()`` presumably yields storage independent of ``arr``
    (as ``numpy.ndarray.flatten`` always copies) — confirm against the
    ndarray wrapper.
    """
    flat = arr.ravel()
    copied = flat.dup()
    return copied
def moveaxis(arr, source, destination):
    """Move axes of ``arr`` to new positions; the other axes keep their order.

    Follows ``numpy.moveaxis``: ``source`` and ``destination`` are each an
    int or a list/tuple of ints (negative values count from the end), and
    must name the same number of axes.

    Args:
        arr: array object exposing ``shape()`` and ``permute(*axes)``.
        source: original position(s) of the axes to move.
        destination: target position(s) for those axes.

    Returns:
        ``arr.permute(*order)`` for the resulting axis order.

    Raises:
        ValueError: if an axis is out of range, an axis is repeated, or
            ``source`` and ``destination`` have different lengths.
    """
    ndim = len(arr.shape())

    def normalize(axes, name):
        # Accept a scalar axis or a list/tuple; map negatives to 0..ndim-1.
        axes = list(axes) if isinstance(axes, (list, tuple)) else [axes]
        result = []
        for a in axes:
            if not -ndim <= a < ndim:
                raise ValueError('%s axis %d is out of bounds for array of '
                                 'dimension %d' % (name, a, ndim))
            result.append(a + ndim if a < 0 else a)
        if len(set(result)) != len(result):
            raise ValueError('repeated axis in `%s` argument' % name)
        return result

    src = normalize(source, 'source')
    dst = normalize(destination, 'destination')
    if len(src) != len(dst):
        raise ValueError('source and destination should have the same '
                         'number of elements.')
    # Start from the axes that are not moved, then insert each moved axis at
    # its destination in increasing destination order; later insertions are
    # at larger indices, so they never shift an axis already placed.
    order = [i for i in range(ndim) if i not in src]
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)
    return arr.permute(*order)
def permute(arr, *axis):
    """Reorder the dimensions of ``arr`` with ``arr.permute``.

    The new axis order may be passed as separate integers
    (``permute(a, 2, 0, 1)``) or as a single list/tuple
    (``permute(a, [2, 0, 1])``); negative values count from the last
    dimension.

    Args:
        arr: array object exposing ``shape()`` and ``permute(*axes)``.
        *axis: the new order of the dimensions.

    Returns:
        ``arr.permute(*axes)`` with non-negative axes.

    Raises:
        ValueError: if the axes are not a permutation of the dimensions
            of ``arr``.
    """
    if len(axis) == 1 and isinstance(axis[0], (list, tuple)):
        axis = axis[0]
    ndim = len(arr.shape())
    # Resolve negative axes, then require each dimension exactly once
    # (explicit raise, since asserts vanish under -O).
    axes = [a + ndim if a < 0 else a for a in axis]
    if sorted(axes) != list(range(ndim)):
        raise ValueError('axes %r are not a permutation of the %d '
                         'dimensions of the array' % (tuple(axis), ndim))
    return arr.permute(*axes)
def expand_dims(arr, axis):
    """Delegate to ``Nd4j.expandDims(arr, axis)`` and return its result.

    Presumably inserts a size-1 dimension at ``axis``, mirroring
    ``numpy.expand_dims`` — confirm against the Nd4j API.
    """
    expander = Nd4j.expandDims
    return expander(arr, axis)
def squeeze(arr, axis=None):
    """Remove length-1 dimensions from ``arr``.

    Args:
        arr: array object exposing ``shape()`` and ``reshape(*dims)``.
        axis: an int, or a list/tuple of ints, naming the dimensions to
            drop; negative values count from the end. When ``None`` (the
            default, as in ``numpy.squeeze``) every size-1 dimension is
            dropped.

    Returns:
        ``arr.reshape(...)`` with the selected dimensions removed.

    Raises:
        ValueError: if a selected axis is out of range or its size is not 1.
    """
    # Copy so the list returned by ``shape()`` is never mutated.
    shape = list(arr.shape())
    ndim = len(shape)
    if axis is None:
        drop = {i for i, size in enumerate(shape) if size == 1}
    else:
        axes = axis if isinstance(axis, (list, tuple)) else (axis,)
        drop = set()
        for a in axes:
            if not -ndim <= a < ndim:
                raise ValueError('axis %d is out of bounds for array of '
                                 'dimension %d' % (a, ndim))
            a = a + ndim if a < 0 else a
            if shape[a] != 1:
                raise ValueError('cannot squeeze axis %d of size %d'
                                 % (a, shape[a]))
            drop.add(a)
    new_shape = [size for i, size in enumerate(shape) if i not in drop]
    return arr.reshape(*new_shape)
def concatenate(arrs, axis=-1):
    """Join the arrays in ``arrs`` along ``axis`` via ``Nd4j.concat``.

    NOTE(review): ``numpy.concatenate`` defaults to ``axis=0``; this
    wrapper defaults to ``axis=-1``.
    """
    joiner = Nd4j.concat
    return joiner(axis, *arrs)
def hstack(arrs):
    """Pass ``arrs`` unchanged to ``Nd4j.hstack`` and return its result."""
    stacker = Nd4j.hstack
    return stacker(arrs)
def vstack(arrs):
    """Pass ``arrs`` unchanged to ``Nd4j.vstack`` and return its result."""
    stacker = Nd4j.vstack
    return stacker(arrs)
def stack(arrs, axis=0):
    """Join same-shaped arrays along a new axis.

    Each input is reshaped with a length-1 dimension inserted at ``axis``,
    and the results are concatenated along that axis with ``Nd4j.concat``.
    A negative ``axis`` counts from the end of the *output* shape, so
    ``axis=-1`` makes the new dimension the last one (as in ``numpy.stack``).

    Args:
        arrs: iterable of arrays exposing ``shape()`` and ``reshape(*dims)``.
            It is not modified.
        axis: position of the new axis in the result (default 0).

    Returns:
        The array returned by ``Nd4j.concat``.

    Raises:
        ValueError: if ``arrs`` is empty.
    """
    # Work on a private list so the caller's sequence is left untouched
    # and tuples/generators are accepted.
    arrs = list(arrs)
    if not arrs:
        raise ValueError('need at least one array to stack')
    ndim = len(arrs[0].shape())
    # The output has one more dimension than each input, so negative axes
    # are resolved against ndim + 1.
    new_axis = axis + ndim + 1 if axis < 0 else axis
    expanded = []
    for arr in arrs:
        # Copy so the list returned by ``shape()`` is never mutated.
        shape = list(arr.shape())
        shape.insert(new_axis, 1)
        expanded.append(arr.reshape(*shape))
    return Nd4j.concat(new_axis, *expanded)
def tile(arr, reps):
    """Repeat ``arr`` according to ``reps`` using ``numpy.tile``.

    ``arr`` is converted with ``_nparray``, tiled by NumPy, and the result
    converted back with ``_indarray``, so ``reps`` follows ``numpy.tile``
    semantics (an int or a sequence of per-axis repetition counts).

    Args:
        arr: array accepted by ``_nparray``.
        reps: repetition count(s) passed straight to ``numpy.tile``.

    Returns:
        The tiled array as produced by ``_indarray``.
    """
    import numpy as np
    return _indarray(np.tile(_nparray(arr), reps))
def repeat(arr, repeats, axis=None):
    """Repeat elements of ``arr`` using ``arr.repeat``.

    An int ``repeats`` is wrapped into a 1-tuple before being unpacked into
    ``arr.repeat``. With ``axis=None`` the repeat is applied along axis
    ``-1`` and the result is reshaped with ``reshape(-1)``; otherwise the
    result of ``arr.repeat(axis, *repeats)`` is returned.
    """
    counts = (repeats,) if type(repeats) is int else repeats
    if axis is not None:
        return arr.repeat(axis, *counts)
    repeated = arr.repeat(-1, *counts)
    return repeated.reshape(-1)