Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-30 21:40:27 +09:00 committed by Alex Black
parent 62e96c9724
commit 54e320a255
6 changed files with 65 additions and 350 deletions

View File

@ -1711,16 +1711,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* Returns the element at the specified row/column
* This will throw an exception if the
*
* @param row the row of the element to return
* @param column the row of the element to return
* @return a scalar indarray of the element at this index
*/
@Override @Override
public INDArray getScalar(long row, long column) { public INDArray getScalar(long row, long column) {
return getScalar(new long[] {row, column}); return getScalar(new long[] {row, column});
@ -1885,27 +1875,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
/**
* Inserts the element at the specified index
*
* @param indices the indices to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override @Override
public INDArray put(int[] indices, INDArray element) { public INDArray put(int[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
if (!element.isScalar()) if (!element.isScalar())
throw new IllegalArgumentException("Unable to insert anything but a scalar"); throw new IllegalArgumentException("Unable to insert anything but a scalar");
if (isRowVector() && indices[0] == 0 && indices.length == 2) { if (isRowVector() && indices[0] == 0 && indices.length == 2) {
int ix = 0; //Shape.offset(javaShapeInformation); int ix = 0;
for (int i = 1; i < indices.length; i++) for (int i = 1; i < indices.length; i++)
ix += indices[i] * stride(i); ix += indices[i] * stride(i);
if (ix >= data.length()) if (ix >= data.length())
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0)); data.put(ix, element.getDouble(0));
} else { } else {
int ix = 0; //Shape.offset(javaShapeInformation); int ix = 0;
for (int i = 0; i < indices.length; i++) for (int i = 0; i < indices.length; i++)
if (size(i) != 1) if (size(i) != 1)
ix += indices[i] * stride(i); ix += indices[i] * stride(i);
@ -1913,10 +1896,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0)); data.put(ix, element.getDouble(0));
} }
return this; return this;
} }
@Override @Override
@ -1970,39 +1950,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return putWhereWithMask(mask,Nd4j.scalar(put)); return putWhereWithMask(mask,Nd4j.scalar(put));
} }
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override @Override
public INDArray put(int i, int j, INDArray element) { public INDArray put(int i, int j, INDArray element) {
return put(new int[] {i, j}, element); return put(new int[] {i, j}, element);
} }
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override @Override
public INDArray put(int i, int j, Number element) { public INDArray put(int i, int j, Number element) {
return putScalar(new int[] {i, j}, element.doubleValue()); return putScalar(new int[] {i, j}, element.doubleValue());
} }
/**
* Assigns the given matrix (put) to the specified slice
*
* @param slice the slice to assign
* @param put the slice to put
* @return this for chainability
*/
@Override @Override
public INDArray putSlice(int slice, INDArray put) { public INDArray putSlice(int slice, INDArray put) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -2102,10 +2059,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getStrides(shape, ordering); return Nd4j.getStrides(shape, ordering);
} }
/**
* Returns the square of the Euclidean distance.
*/
@Override @Override
public double squaredDistance(INDArray other) { public double squaredDistance(INDArray other) {
validateNumericalArray("squaredDistance", false); validateNumericalArray("squaredDistance", false);
@ -2113,9 +2066,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return d2 * d2; return d2 * d2;
} }
/**
* Returns the (euclidean) distance.
*/
@Override @Override
public double distance2(INDArray other) { public double distance2(INDArray other) {
validateNumericalArray("distance2", false); validateNumericalArray("distance2", false);
@ -2123,9 +2073,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue(); return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue();
} }
/**
* Returns the (1-norm) distance.
*/
@Override @Override
public double distance1(INDArray other) { public double distance1(INDArray other) {
validateNumericalArray("distance1", false); validateNumericalArray("distance1", false);
@ -2133,8 +2080,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue(); return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue();
} }
@Override @Override
public INDArray get(INDArray indices) { public INDArray get(INDArray indices) {
if(indices.rank() > 2) { if(indices.rank() > 2) {
@ -2190,49 +2135,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
@Override
public INDArray get(List<List<Integer>> indices) {
INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
for(int i = 0; i < indArrayIndices.length; i++) {
indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
}
boolean hasNext = true;
Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
List<INDArray> resultList = new ArrayList<>();
while(hasNext) {
try {
List<List<Long>> next = iterate.next();
int[][] nextArr = new int[next.size()][];
for(int i = 0; i < next.size(); i++) {
nextArr[i] = Ints.toArray(next.get(i));
}
int[] curr = Ints.concat(nextArr);
INDArray currSlice = this;
for(int j = 0; j < curr.length; j++) {
currSlice = currSlice.slice(curr[j]);
}
//slice drops the first dimension, this adds a 1 to match normal numpy behavior
currSlice = currSlice.reshape(Longs.concat(new long[]{1},currSlice.shape()));
resultList.add(currSlice);
}
catch(NoSuchElementException e) {
hasNext = false;
}
}
return Nd4j.concat(0,resultList.toArray(new INDArray[resultList.size()]));
}
@Override @Override
public INDArray put(INDArray indices, INDArray element) { public INDArray put(INDArray indices, INDArray element) {
if(indices.rank() > 2) { if(indices.rank() > 2) {
@ -2245,7 +2147,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next())); putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next()));
} }
} }
else { else {
List<INDArray> arrList = new ArrayList<>(); List<INDArray> arrList = new ArrayList<>();
@ -2258,8 +2159,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice})); Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
arrList.add(slice(row.getInt(j))); arrList.add(slice(row.getInt(j)));
} }
} }
} }
else if(indices.isRowVector()) { else if(indices.isRowVector()) {
@ -2267,15 +2166,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
arrList.add(slice(indices.getInt(i))); arrList.add(slice(indices.getInt(i)));
} }
} }
} }
return this; return this;
} }
@Override @Override
public INDArray put(INDArrayIndex[] indices, INDArray element) { public INDArray put(INDArrayIndex[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -2343,7 +2237,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this; return this;
} }
/** /**
* Mainly here for people coming from numpy. * Mainly here for people coming from numpy.
* This is equivalent to a call to permute * This is equivalent to a call to permute
@ -2446,7 +2339,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret; return ret;
} }
protected void init(int[] shape, int[] stride) { protected void init(int[] shape, int[] stride) {
//null character //null character
if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') {
@ -2466,7 +2358,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
@Override @Override
public INDArray getScalar(long i) { public INDArray getScalar(long i) {
if (i >= this.length()) if (i >= this.length())
@ -2936,13 +2827,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return dup().rsubiRowVector(rowVector); return dup().rsubiRowVector(rowVector);
} }
/**
* Inserts the element at the specified index
*
* @param i the index insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override @Override
public INDArray put(int i, INDArray element) { public INDArray put(int i, INDArray element) {
Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element); Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element);
@ -3706,12 +3590,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return normmax(false, dimension); return normmax(false, dimension);
} }
/**
* Reverse division
*
* @param other the matrix to divide from
* @return
*/
@Override @Override
public INDArray rdiv(INDArray other) { public INDArray rdiv(INDArray other) {
validateNumericalArray("rdiv", false); validateNumericalArray("rdiv", false);
@ -3722,37 +3600,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* Reverse divsion (in place)
*
* @param other
* @return
*/
@Override @Override
public INDArray rdivi(INDArray other) { public INDArray rdivi(INDArray other) {
return rdivi(other, this); return rdivi(other, this);
} }
/**
* Reverse division
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override @Override
public INDArray rdiv(INDArray other, INDArray result) { public INDArray rdiv(INDArray other, INDArray result) {
validateNumericalArray("rdiv", false); validateNumericalArray("rdiv", false);
return dup().rdivi(other, result); return dup().rdivi(other, result);
} }
/**
* Reverse division (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override @Override
public INDArray rdivi(INDArray other, INDArray result) { public INDArray rdivi(INDArray other, INDArray result) {
validateNumericalArray("rdivi", false); validateNumericalArray("rdivi", false);
@ -3761,23 +3619,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* Reverse subtraction
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override @Override
public INDArray rsub(INDArray other, INDArray result) { public INDArray rsub(INDArray other, INDArray result) {
validateNumericalArray("rsub", false); validateNumericalArray("rsub", false);
return rsubi(other, result); return rsubi(other, result);
} }
/**
* @param other
* @return
*/
@Override @Override
public INDArray rsub(INDArray other) { public INDArray rsub(INDArray other) {
validateNumericalArray("rsub", false); validateNumericalArray("rsub", false);
@ -3788,22 +3635,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* @param other
* @return
*/
@Override @Override
public INDArray rsubi(INDArray other) { public INDArray rsubi(INDArray other) {
return rsubi(other, this); return rsubi(other, this);
} }
/**
* Reverse subtraction (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override @Override
public INDArray rsubi(INDArray other, INDArray result) { public INDArray rsubi(INDArray other, INDArray result) {
validateNumericalArray("rsubi", false); validateNumericalArray("rsubi", false);
@ -3812,12 +3648,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} }
/**
* Set the value of the ndarray to the specified value
*
* @param value the value to assign
* @return the ndarray with the values
*/
@Override @Override
public INDArray assign(Number value) { public INDArray assign(Number value) {
Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " + Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " +
@ -3826,7 +3656,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this; return this;
} }
@Override @Override
public INDArray assign(boolean value) { public INDArray assign(boolean value) {
return assign(value ? 1 : 0); return assign(value ? 1 : 0);
@ -3846,7 +3675,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
@Override @Override
@Deprecated @Deprecated //TODO: Investigate. Not deprecated in the base interface.
public long linearIndex(long i) { public long linearIndex(long i) {
long idx = i; long idx = i;
for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { for (int j = 0; j < jvmShapeInfo.rank - 1; j++) {
@ -4073,15 +3902,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return addi(n, this); return addi(n, this);
} }
/**
* Replicate and tile array to fill out to the given shape
* See:
* https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
* @param shape the new shape of this ndarray
* @return the shape to fill out to
*/
@Override @Override
public INDArray repmat(int[] shape) { public INDArray repmat(int[] shape) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -4109,16 +3929,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return out; return out;
} }
/**
* Insert a row in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param row the row insert into
* @param toPut the row to insert
* @return this
*/
@Override @Override
public INDArray putRow(long row, INDArray toPut) { public INDArray putRow(long row, INDArray toPut) {
if (isRowVector() && toPut.isVector()) { if (isRowVector() && toPut.isVector()) {
@ -4127,15 +3937,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut);
} }
/**
* Insert a column in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param column the column to insert
* @param toPut the array to put
* @return this
*/
@Override @Override
public INDArray putColumn(int column, INDArray toPut) { public INDArray putColumn(int column, INDArray toPut) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -4752,11 +4553,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(length()); return reshape(length());
} }
/**
* Flattens the array for linear indexing
*
* @return the flattened version of this array
*/
@Override @Override
public void sliceVectors(List<INDArray> list) { public void sliceVectors(List<INDArray> list) {
if (isVector()) if (isVector())
@ -4804,12 +4600,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return col.reshape(col.length(), 1); return col.reshape(col.length(), 1);
} }
/**
* Get whole rows from the passed indices.
*
* @param rindices
*/
@Override @Override
public INDArray getRows(int[] rindices) { public INDArray getRows(int[] rindices) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -4826,13 +4616,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override @Override
public INDArray get(INDArrayIndex... indexes) { public INDArray get(INDArrayIndex... indexes) {
Nd4j.getCompressor().autoDecompress(this); Nd4j.getCompressor().autoDecompress(this);
@ -5020,13 +4803,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return out; return out;
} }
/**
* Get whole columns
* from the passed indices.
*
* @param cindices
*/
@Override @Override
public INDArray getColumns(int... cindices) { public INDArray getColumns(int... cindices) {
if (!isMatrix() && !isVector()) if (!isMatrix() && !isVector())

View File

@ -153,7 +153,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
arrList.set(j,put); arrList.set(j,put);
} }
} }
} }
} }
else if(indices.isRowVector()) { else if(indices.isRowVector()) {
@ -161,12 +160,8 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
arrList.add(slice(indices.getInt(i))); arrList.add(slice(indices.getInt(i)));
} }
} }
return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()])); return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()]));
} }
} }
@Override @Override
@ -259,21 +254,13 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null; return null;
} }
@Override
public INDArray get(List<List<Integer>> indices) {
return null;
}
@Override @Override
public INDArray put(INDArray indices, INDArray element) { public INDArray put(INDArray indices, INDArray element) {
INDArrayIndex[] realIndices = new INDArrayIndex[indices.rank()]; INDArrayIndex[] realIndices = new INDArrayIndex[indices.rank()];
for(int i = 0; i < realIndices.length; i++) { for(int i = 0; i < realIndices.length; i++) {
realIndices[i] = new SpecifiedIndex(indices.slice(i).dup().data().asInt()); realIndices[i] = new SpecifiedIndex(indices.slice(i).dup().data().asInt());
} }
return put(realIndices,element); return put(realIndices,element);
} }
@Override @Override

View File

@ -568,13 +568,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return this; return this;
} }
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override @Override
public INDArray get(INDArrayIndex... indexes) { public INDArray get(INDArrayIndex... indexes) {

View File

@ -124,14 +124,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return this; return this;
} }
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override @Override
public INDArray get(INDArrayIndex... indexes) { public INDArray get(INDArrayIndex... indexes) {
//check for row/column vector and point index being 0 //check for row/column vector and point index being 0

View File

@ -458,7 +458,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rsub(Number n); INDArray rsub(Number n);
/** /**
* Reverse subtraction in place - i.e., (n - thisArrayValues) * Reverse subtraction in place - i.e., (n - thisArrayValues)
* *
@ -467,7 +466,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rsubi(Number n); INDArray rsubi(Number n);
/** /**
* Division by a number * Division by a number
* *
@ -484,7 +482,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray divi(Number n); INDArray divi(Number n);
/** /**
* Scalar multiplication (copy) * Scalar multiplication (copy)
* *
@ -501,7 +498,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray muli(Number n); INDArray muli(Number n);
/** /**
* Scalar subtraction (copied) * Scalar subtraction (copied)
* *
@ -510,7 +506,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray sub(Number n); INDArray sub(Number n);
/** /**
* In place scalar subtraction * In place scalar subtraction
* *
@ -535,7 +530,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addi(Number n); INDArray addi(Number n);
/** /**
* Reverse division (number / ndarray) * Reverse division (number / ndarray)
* *
@ -545,7 +539,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rdiv(Number n, INDArray result); INDArray rdiv(Number n, INDArray result);
/** /**
* Reverse in place division * Reverse in place division
* *
@ -573,11 +566,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rsubi(Number n, INDArray result); INDArray rsubi(Number n, INDArray result);
/** /**
* @param n * Division if ndarray by number
* @param result *
* @return * @param n the number to divide by
* @param result the result ndarray
* @return the result ndarray
*/ */
INDArray div(Number n, INDArray result); INDArray div(Number n, INDArray result);
@ -586,24 +580,35 @@ public interface INDArray extends Serializable, AutoCloseable {
* *
* @param n the number to divide by * @param n the number to divide by
* @param result the result ndarray * @param result the result ndarray
* @return * @return the result ndarray
*/ */
INDArray divi(Number n, INDArray result); INDArray divi(Number n, INDArray result);
/**
* Multiplication of ndarray.
*
* @param n the number to multiply by
* @param result the result ndarray
* @return the result ndarray
*/
INDArray mul(Number n, INDArray result); INDArray mul(Number n, INDArray result);
/** /**
* In place multiplication of this ndarray * In place multiplication of this ndarray
* *
* @param n the number to divide by * @param n the number to divide by
* @param result the result ndarray * @param result the result ndarray
* @return * @return the result ndarray
*/ */
INDArray muli(Number n, INDArray result); INDArray muli(Number n, INDArray result);
/**
* Subtraction of this ndarray
*
* @param n the number to subtract by
* @param result the result ndarray
* @return the result ndarray
*/
INDArray sub(Number n, INDArray result); INDArray sub(Number n, INDArray result);
/** /**
@ -615,6 +620,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray subi(Number n, INDArray result); INDArray subi(Number n, INDArray result);
/**
* Addition of this ndarray.
* @param n the number to add
* @param result the result ndarray
* @return the result ndarray
*/
INDArray add(Number n, INDArray result); INDArray add(Number n, INDArray result);
/** /**
@ -626,25 +637,24 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addi(Number n, INDArray result); INDArray addi(Number n, INDArray result);
/** /**
* Returns a subset of this array based on the specified * Returns a subset of this array based on the specified indexes
* indexes
* *
* @param indexes the indexes in to the array * @param indexes the indexes in to the array
* @return a view of the array with the specified indices * @return a view of the array with the specified indices
*/ */
INDArray get(INDArrayIndex... indexes); INDArray get(INDArrayIndex... indexes);
//TODO: revisit after #8166 is resolved.
/** /**
* Return a mask on whether each element * Return a mask on whether each element matches the given condition
* matches the given condition
* @param comp * @param comp
* @param condition * @param condition
* @return * @return
*/ */
INDArray match(INDArray comp,Condition condition); INDArray match(INDArray comp,Condition condition);
//TODO: revisit after #8166 is resolved.
/** /**
* Returns a mask * Returns a mask
* @param comp * @param comp
@ -673,54 +683,51 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray getWhere(Number comp,Condition condition); INDArray getWhere(Number comp,Condition condition);
//TODO: unused / untested method. (only used to forward calls from putWhere(Number,INDArray ,Condition).
/** /**
* Assign the element according * Assign the element according to the comparison array
* to the comparison array
* @param comp the comparison array * @param comp the comparison array
* @param put the elements to put * @param put the elements to put
* @param condition the condition for masking on * @param condition the condition for masking on
* @return * @return a copy of this array with the conditional assignments.
*/ */
INDArray putWhere(INDArray comp,INDArray put,Condition condition); INDArray putWhere(INDArray comp,INDArray put,Condition condition);
//TODO: unused / untested method.
/** /**
* Assign the element according * Assign the element according to the comparison array
* to the comparison array
* @param comp the comparison array * @param comp the comparison array
* @param put the elements to put * @param put the elements to put
* @param condition the condition for masking on * @param condition the condition for masking on
* @return * @return a copy of this array with the conditional assignments.
*/ */
INDArray putWhere(Number comp,INDArray put,Condition condition); INDArray putWhere(Number comp,INDArray put,Condition condition);
//TODO: unused / untested method. (only used to forward calls from other putWhereWithMask implementations.
/** /**
* Use a pre computed mask * Use a pre computed mask for assigning arrays
* for assigning arrays
* @param mask the mask to use * @param mask the mask to use
* @param put the array to put * @param put the array to put
* @return the resulting array * @return a copy of this array with the conditional assignments.
*/ */
INDArray putWhereWithMask(INDArray mask,INDArray put); INDArray putWhereWithMask(INDArray mask,INDArray put);
//TODO: unused / untested method.
/** /**
* Use a pre computed mask * Use a pre computed mask for assigning arrays
* for assigning arrays
* @param mask the mask to use * @param mask the mask to use
* @param put the array to put * @param put the array to put
* @return the resulting array * @return a copy of this array with the conditional assignments.
*/ */
INDArray putWhereWithMask(INDArray mask,Number put); INDArray putWhereWithMask(INDArray mask,Number put);
//TODO: unused / untested method.
/** /**
* Assign the element according * Assign the element according to the comparison array
* to the comparison array
* @param comp the comparison array * @param comp the comparison array
* @param put the elements to put * @param put the elements to put
* @param condition the condition for masking on * @param condition the condition for masking on
* @return * @return a copy of this array with the conditional assignments.
*/ */
INDArray putWhere(Number comp,Number put,Condition condition); INDArray putWhere(Number comp,Number put,Condition condition);
@ -731,14 +738,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray get(INDArray indices); INDArray get(INDArray indices);
/**
* Get the elements from this ndarray based on the specified indices
* @param indices an ndaray of the indices to get the elements for
* @return the elements to get the array for
*/
@Deprecated
INDArray get(List<List<Integer>> indices);
/** /**
* Get an INDArray comprised of the specified columns only. Copy operation. * Get an INDArray comprised of the specified columns only. Copy operation.
* *
@ -771,20 +770,20 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rdivi(INDArray other); INDArray rdivi(INDArray other);
//TODO: unused / untested method.
/** /**
* Reverse division * Reverse division
* *
* @param other the matrix to subtract from * @param other the matrix to divide from
* @param result the result ndarray * @param result the result ndarray
* @return * @return the result ndarray
*/ */
INDArray rdiv(INDArray other, INDArray result); INDArray rdiv(INDArray other, INDArray result);
/** /**
* Reverse division (in-place) * Reverse division (in-place)
* *
* @param other the other ndarray to subtract * @param other the matrix to divide from
* @param result the result ndarray * @param result the result ndarray
* @return the ndarray with the operation applied * @return the ndarray with the operation applied
*/ */
@ -795,11 +794,10 @@ public interface INDArray extends Serializable, AutoCloseable {
* *
* @param other the matrix to subtract from * @param other the matrix to subtract from
* @param result the result ndarray * @param result the result ndarray
* @return * @return the result ndarray
*/ */
INDArray rsub(INDArray other, INDArray result); INDArray rsub(INDArray other, INDArray result);
/** /**
* Element-wise reverse subtraction (copy op). i.e., other - this * Element-wise reverse subtraction (copy op). i.e., other - this
* *
@ -842,22 +840,19 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray assign(boolean value); INDArray assign(boolean value);
/** /**
* Get the linear index of the data in to * Get the linear index of the data in to the array
* the array
* *
* @param i the index to getScalar * @param i the index to getScalar
* @return the linear index in to the data * @return the linear index in to the data
*/ */
long linearIndex(long i); long linearIndex(long i);
//TODO: unused / untested method. only used recursively.
/** /**
* * Flattens the array for linear indexing in list.
* @param list
*/ */
void sliceVectors(List<INDArray> list); void sliceVectors(List<INDArray> list);
/** /**
* Assigns the given matrix (put) to the specified slice * Assigns the given matrix (put) to the specified slice
* *
@ -875,16 +870,15 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray cond(Condition condition); INDArray cond(Condition condition);
/** /**
* Replicate and tile array to fill out to the given shape * Replicate and tile array to fill out to the given shape
* * See:
* https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
* @param shape the new shape of this ndarray * @param shape the new shape of this ndarray
* @return the shape to fill out to * @return the shape to fill out to
*/ */
INDArray repmat(int... shape); INDArray repmat(int... shape);
/** /**
* Repeat elements along a specified dimension. * Repeat elements along a specified dimension.
* *
@ -894,11 +888,9 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray repeat(int dimension, long... repeats); INDArray repeat(int dimension, long... repeats);
/** /**
* Insert a row in to this array * Insert a row in to this array
* Will throw an exception if this * Will throw an exception if this ndarray is not a matrix
* ndarray is not a matrix
* *
* @param row the row insert into * @param row the row insert into
* @param toPut the row to insert * @param toPut the row to insert
@ -908,8 +900,7 @@ public interface INDArray extends Serializable, AutoCloseable {
/** /**
* Insert a column in to this array * Insert a column in to this array
* Will throw an exception if this * Will throw an exception if this ndarray is not a matrix
* ndarray is not a matrix
* *
* @param column the column to insert * @param column the column to insert
* @param toPut the array to put * @param toPut the array to put
@ -919,7 +910,6 @@ public interface INDArray extends Serializable, AutoCloseable {
/** /**
* Returns the element at the specified row/column * Returns the element at the specified row/column
* This will throw an exception if the
* *
* @param row the row of the element to return * @param row the row of the element to return
* @param column the row of the element to return * @param column the row of the element to return
@ -950,26 +940,20 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
double distance1(INDArray other); double distance1(INDArray other);
/** /**
* Put element in to the indices denoted by * Put element in to the indices denoted by
* the indices ndarray. * the indices ndarray.
* This is equivalent to: * In numpy this is equivalent to:
* a[indices] = element * a[indices] = element
* *
* in numpy.
*
* @param indices the indices to put * @param indices the indices to put
* @param element the element array to put * @param element the element array to put
* @return this array * @return this array
*/ */
INDArray put(INDArray indices,INDArray element); INDArray put(INDArray indices,INDArray element);
/** /**
* Put the elements of the ndarray * Put the elements of the ndarray in to the specified indices
* in to the specified indices
* *
* @param indices the indices to put the ndarray in to * @param indices the indices to put the ndarray in to
* @param element the ndarray to put * @param element the ndarray to put
@ -978,8 +962,7 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray put(INDArrayIndex[] indices, INDArray element); INDArray put(INDArrayIndex[] indices, INDArray element);
/** /**
* Put the elements of the ndarray * Put the elements of the ndarray in to the specified indices
* in to the specified indices
* *
* @param indices the indices to put the ndarray in to * @param indices the indices to put the ndarray in to
* @param element the ndarray to put * @param element the ndarray to put
@ -1007,7 +990,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray put(int i, int j, INDArray element); INDArray put(int i, int j, INDArray element);
/** /**
* Inserts the element at the specified index * Inserts the element at the specified index
* *
@ -1018,7 +1000,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray put(int i, int j, Number element); INDArray put(int i, int j, Number element);
/** /**
* Inserts the element at the specified index * Inserts the element at the specified index
* *
@ -1028,7 +1009,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray put(int i, INDArray element); INDArray put(int i, INDArray element);
/** /**
* In place division of a column vector * In place division of a column vector
* *

View File

@ -101,19 +101,6 @@ public class IndexingTests extends BaseNd4jTest {
assertEquals(vals,x); assertEquals(vals,x);
} }
@Test @Ignore
public void testIndexGetDuplicate() {
List<List<Integer>> indices = new ArrayList<>();
indices.add(Arrays.asList(0,0));
INDArray linspace = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray get = linspace.get(indices);
INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
assertEquals(assertion,get);
}
@Test @Test
public void testGetScalar() { public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);