2021-02-01 14:31:20 +09:00
|
|
|
/*
|
|
|
|
|
* ******************************************************************************
|
|
|
|
|
* *
|
|
|
|
|
* *
|
|
|
|
|
* * This program and the accompanying materials are made available under the
|
|
|
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
* *
|
2021-02-01 17:47:29 +09:00
|
|
|
* * See the NOTICE file distributed with this work for additional
|
|
|
|
|
* * information regarding copyright ownership.
|
2021-02-01 14:31:20 +09:00
|
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* * License for the specific language governing permissions and limitations
|
|
|
|
|
* * under the License.
|
|
|
|
|
* *
|
|
|
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
* *****************************************************************************
|
|
|
|
|
*/
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
package org.nd4j.linalg.indexing;
|
|
|
|
|
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
import lombok.val;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.common.base.Preconditions;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
|
import org.nd4j.linalg.api.shape.Shape;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.common.util.ArrayUtil;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
|
|
import java.util.Arrays;
|
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* NDArray indexing
|
|
|
|
|
*
|
|
|
|
|
* @author Adam Gibson
|
|
|
|
|
*/
|
|
|
|
|
@Slf4j
|
|
|
|
|
public abstract class NDArrayIndex implements INDArrayIndex {
|
|
|
|
|
|
2022-10-21 15:19:32 +02:00
|
|
|
private final long[] indices;
|
|
|
|
|
private static final NewAxis NEW_AXIS = new NewAxis();
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Returns a point index
|
|
|
|
|
* @param point the point index
|
|
|
|
|
* @return the point index based
|
|
|
|
|
* on the specified point
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex point(long point) {
|
|
|
|
|
return new PointIndex(point);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Add indexes for the given shape
|
|
|
|
|
* @param shape the shape ot convert to indexes
|
|
|
|
|
* @return the indexes for the given shape
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] indexesFor(long... shape) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[shape.length];
|
|
|
|
|
for (int i = 0; i < shape.length; i++) {
|
|
|
|
|
ret[i] = NDArrayIndex.point(shape[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Compute the offset given an array of offsets.
|
|
|
|
|
* The offset is computed(for both fortran an d c ordering) as:
|
|
|
|
|
* sum from i to n - 1 o[i] * s[i]
|
|
|
|
|
* where i is the index o is the offset and s is the stride
|
|
|
|
|
* Notice the -1 at the end.
|
|
|
|
|
* @param arr the array to compute the offset for
|
|
|
|
|
* @param offsets the offsets for each dimension
|
|
|
|
|
* @return the offset that should be used for indexing
|
|
|
|
|
*/
|
|
|
|
|
public static long offset(INDArray arr, long... offsets) {
|
|
|
|
|
return offset(arr.stride(), offsets);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Compute the offset given an array of offsets.
|
|
|
|
|
* The offset is computed(for both fortran an d c ordering) as:
|
|
|
|
|
* sum from i to n - 1 o[i] * s[i]
|
|
|
|
|
* where i is the index o is the offset and s is the stride
|
|
|
|
|
* Notice the -1 at the end.
|
|
|
|
|
* @param arr the array to compute the offset for
|
|
|
|
|
* @param indices the offsets for each dimension
|
|
|
|
|
* @return the offset that should be used for indexing
|
|
|
|
|
*/
|
|
|
|
|
public static long offset(INDArray arr, INDArrayIndex... indices) {
|
|
|
|
|
return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Compute the offset given an array of offsets.
|
|
|
|
|
* The offset is computed(for both fortran an d c ordering) as:
|
|
|
|
|
* sum from i to n - 1 o[i] * s[i]
|
|
|
|
|
* where i is the index o is the offset and s is the stride
|
|
|
|
|
* Notice the -1 at the end.
|
|
|
|
|
* @param strides the strides to compute the offset for
|
|
|
|
|
* @param offsets the offsets for each dimension
|
|
|
|
|
* @return the offset that should be used for indexing
|
|
|
|
|
*/
|
|
|
|
|
public static long offset(long[] strides, long[] offsets) {
|
|
|
|
|
int ret = 0;
|
|
|
|
|
|
|
|
|
|
if (ArrayUtil.prod(offsets) == 1) {
|
|
|
|
|
for (int i = 0; i < offsets.length; i++) {
|
|
|
|
|
ret += offsets[i] * strides[i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < offsets.length; i++) {
|
|
|
|
|
ret += offsets[i] * strides[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static long offset(int[] strides, long[] offsets) {
|
|
|
|
|
int ret = 0;
|
|
|
|
|
|
|
|
|
|
if (ArrayUtil.prodLong(offsets) == 1) {
|
|
|
|
|
for (int i = 0; i < offsets.length; i++) {
|
|
|
|
|
ret += offsets[i] * strides[i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < offsets.length; i++) {
|
|
|
|
|
ret += offsets[i] * strides[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Repeat a copy of copy n times
|
|
|
|
|
* @param copy the ndarray index to copy
|
|
|
|
|
* @param n the number of times to copy
|
|
|
|
|
* @return an array of length n containing copies of
|
|
|
|
|
* the given ndarray index
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] nTimes(INDArrayIndex copy, int n) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[n];
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
ret[i] = copy;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Creates an index backed by the given index values.
 * The array is stored as-is (it is not copied).
 *
 * @param indices the index values
 */
public NDArrayIndex(long... indices) {
    this.indices = indices;
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Represents collecting all elements
|
|
|
|
|
*
|
|
|
|
|
* @return an ndarray index
|
|
|
|
|
* meaning collect
|
|
|
|
|
* all elements
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex all() {
|
|
|
|
|
return new NDArrayIndexAll();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Returns an instance of {@link SpecifiedIndex}.
|
|
|
|
|
* Note that SpecifiedIndex works differently than the other indexing options, in that it always returns a copy
|
|
|
|
|
* of the (subset of) the underlying array, for get operations. This means that INDArray.get(..., indices(x,y,z), ...)
|
|
|
|
|
* will be a copy of the relevant subset of the array.
|
|
|
|
|
* @param indices Indices to get
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex indices(long... indices){
|
|
|
|
|
return new SpecifiedIndex(indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Represents adding a new dimension
|
|
|
|
|
* @return the indexing for
|
|
|
|
|
* adding a new dimension
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex newAxis() {
|
|
|
|
|
return NEW_AXIS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Given an all index and
|
|
|
|
|
* the intended indexes, return an
|
|
|
|
|
* index array containing a combination of all elements
|
|
|
|
|
* for slicing and overriding particular indexes where necessary
|
|
|
|
|
* @param arr the array to resolve indexes for
|
|
|
|
|
* @param intendedIndexes the indexes specified by the user
|
|
|
|
|
* @return the resolved indexes (containing all where nothing is specified, and the intended index
|
|
|
|
|
* for a particular dimension otherwise)
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] resolve(INDArray arr, INDArrayIndex... intendedIndexes) {
|
|
|
|
|
return resolve(NDArrayIndex.allFor(arr), intendedIndexes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Number of point indexes
|
|
|
|
|
* @param indexes the indexes
|
|
|
|
|
* to count for points
|
|
|
|
|
* @return the number of point indexes
|
|
|
|
|
* in the array
|
|
|
|
|
*/
|
|
|
|
|
public static int numPoints(INDArrayIndex... indexes) {
|
|
|
|
|
int ret = 0;
|
|
|
|
|
for (int i = 0; i < indexes.length; i++)
|
|
|
|
|
if (indexes[i] instanceof PointIndex)
|
|
|
|
|
ret++;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Given an all index and
|
|
|
|
|
* the intended indexes, return an
|
|
|
|
|
* index array containing a combination of all elements
|
|
|
|
|
* for slicing and overriding particular indexes where necessary
|
|
|
|
|
* @param shapeInfo the index containing all elements
|
|
|
|
|
* @param intendedIndexes the indexes specified by the user
|
|
|
|
|
* @return the resolved indexes (containing all where nothing is specified, and the intended index
|
|
|
|
|
* for a particular dimension otherwise)
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] resolveLong(long[] shapeInfo, INDArrayIndex... intendedIndexes) {
|
|
|
|
|
int numSpecified = 0;
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
if (intendedIndexes[i] instanceof SpecifiedIndex)
|
|
|
|
|
numSpecified++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (numSpecified > 0) {
|
|
|
|
|
val shape = Shape.shapeOf(shapeInfo);
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[intendedIndexes.length];
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
if (intendedIndexes[i] instanceof SpecifiedIndex)
|
|
|
|
|
ret[i] = intendedIndexes[i];
|
|
|
|
|
else {
|
|
|
|
|
if (intendedIndexes[i] instanceof NDArrayIndexAll) {
|
|
|
|
|
SpecifiedIndex specifiedIndex = new SpecifiedIndex(ArrayUtil.range(0L, shape[i]));
|
|
|
|
|
ret[i] = specifiedIndex;
|
|
|
|
|
} else if (intendedIndexes[i] instanceof IntervalIndex) {
|
|
|
|
|
IntervalIndex intervalIndex = (IntervalIndex) intendedIndexes[i];
|
|
|
|
|
ret[i] = new SpecifiedIndex(ArrayUtil.range(intervalIndex.begin, intervalIndex.end(),
|
|
|
|
|
intervalIndex.stride()));
|
|
|
|
|
} else if(intendedIndexes[i] instanceof PointIndex){
|
|
|
|
|
ret[i] = intendedIndexes[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* If it's a vector and index asking
|
|
|
|
|
* for a scalar just return the array
|
|
|
|
|
*/
|
|
|
|
|
int rank = Shape.rank(shapeInfo);
|
|
|
|
|
val shape = Shape.shapeOf(shapeInfo);
|
|
|
|
|
if (intendedIndexes.length >= rank || Shape.isVector(shapeInfo) && intendedIndexes.length == 1) {
|
|
|
|
|
if(Shape.rank(shapeInfo) == 1){
|
|
|
|
|
//1D edge case, with 1 index
|
|
|
|
|
return intendedIndexes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (Shape.isRowVectorShape(shapeInfo) && intendedIndexes.length == 1) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[2];
|
|
|
|
|
ret[0] = NDArrayIndex.point(0);
|
|
|
|
|
long size;
|
|
|
|
|
if (1 == shape[0] && rank == 2)
|
|
|
|
|
size = shape[1];
|
|
|
|
|
else
|
|
|
|
|
size = shape[0];
|
|
|
|
|
ret[1] = validate(size, intendedIndexes[0]);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length);
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
if (i < rank)
|
|
|
|
|
retList.add(validate(shape[i], intendedIndexes[i]));
|
|
|
|
|
else
|
|
|
|
|
retList.add(intendedIndexes[i]);
|
|
|
|
|
}
|
|
|
|
|
return retList.toArray(new INDArrayIndex[retList.size()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length + 1);
|
|
|
|
|
int numNewAxes = 0;
|
|
|
|
|
|
|
|
|
|
if (Shape.isMatrix(shape) && intendedIndexes.length == 1) {
|
|
|
|
|
retList.add(validate(shape[0], intendedIndexes[0]));
|
|
|
|
|
retList.add(NDArrayIndex.all());
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
retList.add(validate(shape[i], intendedIndexes[i]));
|
|
|
|
|
if (intendedIndexes[i] instanceof NewAxis)
|
|
|
|
|
numNewAxes++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int length = rank + numNewAxes;
|
|
|
|
|
//fill the rest with all
|
|
|
|
|
while (retList.size() < length)
|
|
|
|
|
retList.add(NDArrayIndex.all());
|
|
|
|
|
|
|
|
|
|
return retList.toArray(new INDArrayIndex[retList.size()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Given an all index and
|
|
|
|
|
* the intended indexes, return an
|
|
|
|
|
* index array containing a combination of all elements
|
|
|
|
|
* for slicing and overriding particular indexes where necessary
|
|
|
|
|
* @param shape the index containing all elements
|
|
|
|
|
* @param intendedIndexes the indexes specified by the user
|
|
|
|
|
* @return the resolved indexes (containing all where nothing is specified, and the intended index
|
|
|
|
|
* for a particular dimension otherwise)
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] resolve(int[] shape, INDArrayIndex... intendedIndexes) {
|
|
|
|
|
return resolve(ArrayUtil.toLongArray(shape), intendedIndexes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static INDArrayIndex[] resolve(long[] shape, INDArrayIndex... intendedIndexes) {
|
|
|
|
|
/**
|
|
|
|
|
* If it's a vector and index asking for a scalar just return the array
|
|
|
|
|
*/
|
|
|
|
|
if (intendedIndexes.length >= shape.length || Shape.isVector(shape) && intendedIndexes.length == 1) {
|
|
|
|
|
if (Shape.isRowVectorShape(shape) && intendedIndexes.length == 1) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[2];
|
|
|
|
|
ret[0] = NDArrayIndex.point(0);
|
|
|
|
|
long size;
|
|
|
|
|
if (1 == shape[0] && shape.length == 2)
|
|
|
|
|
size = shape[1];
|
|
|
|
|
else
|
|
|
|
|
size = shape[0];
|
|
|
|
|
ret[1] = validate(size, intendedIndexes[0]);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length);
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
if (i < shape.length)
|
|
|
|
|
retList.add(validate(shape[i], intendedIndexes[i]));
|
|
|
|
|
else
|
|
|
|
|
retList.add(intendedIndexes[i]);
|
|
|
|
|
}
|
|
|
|
|
return retList.toArray(new INDArrayIndex[retList.size()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length + 1);
|
|
|
|
|
int numNewAxes = 0;
|
|
|
|
|
|
|
|
|
|
if (Shape.isMatrix(shape) && intendedIndexes.length == 1) {
|
|
|
|
|
retList.add(validate(shape[0], intendedIndexes[0]));
|
|
|
|
|
retList.add(NDArrayIndex.all());
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < intendedIndexes.length; i++) {
|
|
|
|
|
retList.add(validate(shape[i], intendedIndexes[i]));
|
|
|
|
|
if (intendedIndexes[i] instanceof NewAxis)
|
|
|
|
|
numNewAxes++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int length = shape.length + numNewAxes;
|
|
|
|
|
//fill the rest with all
|
|
|
|
|
while (retList.size() < length)
|
|
|
|
|
retList.add(NDArrayIndex.all());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return retList.toArray(new INDArrayIndex[retList.size()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected static INDArrayIndex validate(long size, INDArrayIndex index) {
|
|
|
|
|
if ((index instanceof IntervalIndex || index instanceof PointIndex) && size <= index.offset())
|
|
|
|
|
throw new IllegalArgumentException("NDArrayIndex is out of range. Beginning index: " + index.offset()
|
|
|
|
|
+ " must be less than its size: " + size);
|
|
|
|
|
if (index instanceof IntervalIndex && index.end() > size)
|
|
|
|
|
throw new IllegalArgumentException("NDArrayIndex is out of range. End index: " + index.end()
|
|
|
|
|
+ " must be less than its size: " + size);
|
|
|
|
|
if (index instanceof IntervalIndex && size < index.end()) {
|
|
|
|
|
long begin = ((IntervalIndex) index).begin;
|
|
|
|
|
index = NDArrayIndex.interval(begin, index.stride(), size);
|
|
|
|
|
}
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Given an all index and
|
|
|
|
|
* the intended indexes, return an
|
|
|
|
|
* index array containing a combination of all elements
|
|
|
|
|
* for slicing and overriding particular indexes where necessary
|
|
|
|
|
* @param allIndex the index containing all elements
|
|
|
|
|
* @param intendedIndexes the indexes specified by the user
|
|
|
|
|
* @return the resolved indexes (containing all where nothing is specified, and the intended index
|
|
|
|
|
* for a particular dimension otherwise)
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] resolve(INDArrayIndex[] allIndex, INDArrayIndex... intendedIndexes) {
|
|
|
|
|
|
|
|
|
|
int numNewAxes = numNewAxis(intendedIndexes);
|
|
|
|
|
INDArrayIndex[] all = new INDArrayIndex[allIndex.length + numNewAxes];
|
|
|
|
|
Arrays.fill(all, NDArrayIndex.all());
|
|
|
|
|
for (int i = 0; i < allIndex.length; i++) {
|
|
|
|
|
//collapse single length indexes in to point indexes
|
|
|
|
|
if (i >= intendedIndexes.length)
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
|
if (intendedIndexes[i] instanceof NDArrayIndex) {
|
|
|
|
|
NDArrayIndex idx = (NDArrayIndex) intendedIndexes[i];
|
|
|
|
|
if (idx.indices.length == 1)
|
|
|
|
|
intendedIndexes[i] = new PointIndex(idx.indices[0]);
|
|
|
|
|
}
|
|
|
|
|
all[i] = intendedIndexes[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return all;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Given an array of indexes
|
|
|
|
|
* return the number of new axis elements
|
|
|
|
|
* in teh array
|
|
|
|
|
* @param axes the indexes to get the number
|
|
|
|
|
* of new axes for
|
|
|
|
|
* @return the number of new axis elements in the given array
|
|
|
|
|
*/
|
|
|
|
|
public static int numNewAxis(INDArrayIndex... axes) {
|
|
|
|
|
int ret = 0;
|
|
|
|
|
for (INDArrayIndex index : axes)
|
|
|
|
|
if (index instanceof NewAxis)
|
|
|
|
|
ret++;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generate an all index
|
|
|
|
|
* equal to the rank of the given array
|
|
|
|
|
* @param arr the array to generate the all index for
|
|
|
|
|
* @return an ndarray index array containing of length
|
|
|
|
|
* arr.rank() containing all elements
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] allFor(INDArray arr) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[arr.rank()];
|
|
|
|
|
for (int i = 0; i < ret.length; i++)
|
|
|
|
|
ret[i] = NDArrayIndex.all();
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Creates an index covering the given shape
|
|
|
|
|
* (for each dimension 0,shape[i])
|
|
|
|
|
* @param shape the shape to cover
|
|
|
|
|
* @return the ndarray indexes to cover
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] createCoveringShape(int[] shape) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[shape.length];
|
|
|
|
|
for (int i = 0; i < ret.length; i++) {
|
|
|
|
|
ret[i] = NDArrayIndex.interval(0, shape[i]);
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static INDArrayIndex[] createCoveringShape(long[] shape) {
|
|
|
|
|
INDArrayIndex[] ret = new INDArrayIndex[shape.length];
|
|
|
|
|
for (int i = 0; i < ret.length; i++) {
|
|
|
|
|
ret[i] = NDArrayIndex.interval(0, shape[i]);
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Create a range based on the given indexes.
|
|
|
|
|
* This is similar to create covering shape in that it approximates
|
|
|
|
|
* the length of each dimension (ignoring elements) and
|
|
|
|
|
* reproduces an index of the same dimension and length.
|
|
|
|
|
*
|
|
|
|
|
* @param indexes the indexes to create the range for
|
|
|
|
|
* @return the index ranges.
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex[] rangeOfLength(INDArrayIndex[] indexes) {
|
|
|
|
|
INDArrayIndex[] indexesRet = new INDArrayIndex[indexes.length];
|
|
|
|
|
for (int i = 0; i < indexes.length; i++)
|
|
|
|
|
indexesRet[i] = NDArrayIndex.interval(0, indexes[i].length());
|
|
|
|
|
return indexesRet;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generates an interval from begin (inclusive) to end (exclusive)
|
|
|
|
|
*
|
|
|
|
|
* @param begin the begin
|
|
|
|
|
* @param stride the stride at which to increment
|
|
|
|
|
* @param end the end index
|
|
|
|
|
* @param max the max length for this domain
|
|
|
|
|
* @return the interval
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex interval(long begin, long stride, long end,long max) {
|
|
|
|
|
if(begin < 0) {
|
|
|
|
|
begin += max;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if(end < 0) {
|
|
|
|
|
end += max;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (Math.abs(begin - end) < 1)
|
|
|
|
|
end++;
|
|
|
|
|
if (stride > 1 && Math.abs(begin - end) == 1) {
|
|
|
|
|
end *= stride;
|
|
|
|
|
}
|
|
|
|
|
return interval(begin, stride, end, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generates an interval from begin (inclusive) to end (exclusive)
|
|
|
|
|
*
|
|
|
|
|
* @param begin the begin
|
|
|
|
|
* @param stride the stride at which to increment
|
|
|
|
|
* @param end the end index
|
|
|
|
|
* @return the interval
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex interval(long begin, long stride, long end) {
|
|
|
|
|
if (Math.abs(begin - end) < 1)
|
|
|
|
|
end++;
|
|
|
|
|
if (stride > 1 && Math.abs(begin - end) == 1) {
|
|
|
|
|
end *= stride;
|
|
|
|
|
}
|
|
|
|
|
return interval(begin, stride, end, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generates an interval from begin (inclusive) to end (exclusive)
|
|
|
|
|
*
|
|
|
|
|
* @param begin the begin
|
|
|
|
|
* @param stride the stride at which to increment
|
|
|
|
|
* @param end the end index
|
|
|
|
|
* @param inclusive whether the end should be inclusive or not
|
|
|
|
|
* @return the interval
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex interval(int begin, int stride, int end, boolean inclusive) {
|
|
|
|
|
Preconditions.checkArgument(begin <= end, "Beginning index (%s) in range must be less than or equal to end (%s)", begin, end);
|
|
|
|
|
INDArrayIndex index = new IntervalIndex(inclusive, stride);
|
|
|
|
|
index.init(begin, end);
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static INDArrayIndex interval(long begin, long stride, long end,long max, boolean inclusive) {
|
|
|
|
|
Preconditions.checkArgument(begin <= end, "Beginning index (%s) in range must be less than or equal to end (%s)", begin, end);
|
|
|
|
|
INDArrayIndex index = new IntervalIndex(inclusive, stride);
|
|
|
|
|
index.init(begin, end);
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static INDArrayIndex interval(long begin, long stride, long end, boolean inclusive) {
|
|
|
|
|
Preconditions.checkArgument(begin <= end, "Beginning index (%s) in range must be less than or equal to end (%s)", begin, end);
|
|
|
|
|
INDArrayIndex index = new IntervalIndex(inclusive, stride);
|
|
|
|
|
index.init(begin, end);
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generates an interval from begin (inclusive) to end (exclusive)
|
|
|
|
|
*
|
|
|
|
|
* @param begin the begin
|
|
|
|
|
* @param end the end index
|
|
|
|
|
* @return the interval
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex interval(int begin, int end) {
|
|
|
|
|
return interval(begin, 1, end, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static INDArrayIndex interval(long begin, long end) {
|
|
|
|
|
return interval(begin, 1, end, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Generates an interval from begin (inclusive) to end (exclusive)
|
|
|
|
|
*
|
|
|
|
|
* @param begin the begin
|
|
|
|
|
* @param end the end index
|
|
|
|
|
* @param inclusive whether the end should be inclusive or not
|
|
|
|
|
* @return the interval
|
|
|
|
|
*/
|
|
|
|
|
public static INDArrayIndex interval(long begin, long end, boolean inclusive) {
|
|
|
|
|
return interval(begin, 1, end, inclusive);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public long end() {
|
|
|
|
|
if (indices != null && indices.length > 0)
|
|
|
|
|
return indices[indices.length - 1];
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public long offset() {
|
|
|
|
|
if (indices.length < 1)
|
|
|
|
|
return 0;
|
|
|
|
|
return indices[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Returns the length of the indices
|
|
|
|
|
*
|
|
|
|
|
* @return the length of the range
|
|
|
|
|
*/
|
|
|
|
|
@Override
|
|
|
|
|
public long length() {
|
|
|
|
|
return indices.length;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public long stride() {
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void reverse() {
|
|
|
|
|
ArrayUtil.reverse(indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public String toString() {
|
|
|
|
|
return "NDArrayIndex{" + "indices=" + Arrays.toString(indices) + '}';
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean equals(Object o) {
|
|
|
|
|
if (this == o)
|
|
|
|
|
return true;
|
|
|
|
|
if (!(o instanceof INDArrayIndex))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
NDArrayIndex that = (NDArrayIndex) o;
|
|
|
|
|
|
2022-10-21 15:19:32 +02:00
|
|
|
return Arrays.equals(indices, that.indices);
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public int hashCode() {
|
|
|
|
|
return Arrays.hashCode(indices);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void init(INDArray arr, long begin, int dimension) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void init(INDArray arr, int dimension) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void init(long begin, long end, long max) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void init(long begin, long end) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|