Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-30 21:40:27 +09:00 committed by Alex Black
parent 62e96c9724
commit 54e320a255
6 changed files with 65 additions and 350 deletions

View File

@ -1711,16 +1711,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* Returns the element at the specified row/column
* This will throw an exception if the
*
* @param row the row of the element to return
* @param column the row of the element to return
* @return a scalar indarray of the element at this index
*/
@Override
public INDArray getScalar(long row, long column) {
return getScalar(new long[] {row, column});
@ -1885,27 +1875,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
/**
* Inserts the element at the specified index
*
* @param indices the indices to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this);
if (!element.isScalar())
throw new IllegalArgumentException("Unable to insert anything but a scalar");
if (isRowVector() && indices[0] == 0 && indices.length == 2) {
int ix = 0; //Shape.offset(javaShapeInformation);
int ix = 0;
for (int i = 1; i < indices.length; i++)
ix += indices[i] * stride(i);
if (ix >= data.length())
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0));
} else {
int ix = 0; //Shape.offset(javaShapeInformation);
int ix = 0;
for (int i = 0; i < indices.length; i++)
if (size(i) != 1)
ix += indices[i] * stride(i);
@ -1913,10 +1896,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0));
}
return this;
}
@Override
@ -1970,39 +1950,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return putWhereWithMask(mask,Nd4j.scalar(put));
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, INDArray element) {
return put(new int[] {i, j}, element);
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, Number element) {
return putScalar(new int[] {i, j}, element.doubleValue());
}
/**
* Assigns the given matrix (put) to the specified slice
*
* @param slice the slice to assign
* @param put the slice to put
* @return this for chainability
*/
@Override
public INDArray putSlice(int slice, INDArray put) {
Nd4j.getCompressor().autoDecompress(this);
@ -2102,10 +2059,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getStrides(shape, ordering);
}
/**
* Returns the square of the Euclidean distance.
*/
@Override
public double squaredDistance(INDArray other) {
validateNumericalArray("squaredDistance", false);
@ -2113,9 +2066,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return d2 * d2;
}
/**
* Returns the (euclidean) distance.
*/
@Override
public double distance2(INDArray other) {
validateNumericalArray("distance2", false);
@ -2123,9 +2073,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue();
}
/**
* Returns the (1-norm) distance.
*/
@Override
public double distance1(INDArray other) {
validateNumericalArray("distance1", false);
@ -2133,8 +2080,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue();
}
@Override
public INDArray get(INDArray indices) {
if(indices.rank() > 2) {
@ -2190,49 +2135,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
public INDArray get(List<List<Integer>> indices) {
INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
for(int i = 0; i < indArrayIndices.length; i++) {
indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
}
boolean hasNext = true;
Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
List<INDArray> resultList = new ArrayList<>();
while(hasNext) {
try {
List<List<Long>> next = iterate.next();
int[][] nextArr = new int[next.size()][];
for(int i = 0; i < next.size(); i++) {
nextArr[i] = Ints.toArray(next.get(i));
}
int[] curr = Ints.concat(nextArr);
INDArray currSlice = this;
for(int j = 0; j < curr.length; j++) {
currSlice = currSlice.slice(curr[j]);
}
//slice drops the first dimension, this adds a 1 to match normal numpy behavior
currSlice = currSlice.reshape(Longs.concat(new long[]{1},currSlice.shape()));
resultList.add(currSlice);
}
catch(NoSuchElementException e) {
hasNext = false;
}
}
return Nd4j.concat(0,resultList.toArray(new INDArray[resultList.size()]));
}
@Override
public INDArray put(INDArray indices, INDArray element) {
if(indices.rank() > 2) {
@ -2245,7 +2147,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next()));
}
}
else {
List<INDArray> arrList = new ArrayList<>();
@ -2258,8 +2159,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
arrList.add(slice(row.getInt(j)));
}
}
}
else if(indices.isRowVector()) {
@ -2267,15 +2166,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
arrList.add(slice(indices.getInt(i)));
}
}
}
return this;
}
@Override
public INDArray put(INDArrayIndex[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this);
@ -2343,7 +2237,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
}
/**
* Mainly here for people coming from numpy.
* This is equivalent to a call to permute
@ -2446,7 +2339,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret;
}
protected void init(int[] shape, int[] stride) {
//null character
if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') {
@ -2466,7 +2358,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
public INDArray getScalar(long i) {
if (i >= this.length())
@ -2936,13 +2827,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return dup().rsubiRowVector(rowVector);
}
/**
* Inserts the element at the specified index
*
* @param i the index insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, INDArray element) {
Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element);
@ -3706,12 +3590,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return normmax(false, dimension);
}
/**
* Reverse division
*
* @param other the matrix to divide from
* @return
*/
@Override
public INDArray rdiv(INDArray other) {
validateNumericalArray("rdiv", false);
@ -3722,37 +3600,17 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* Reverse divsion (in place)
*
* @param other
* @return
*/
@Override
public INDArray rdivi(INDArray other) {
return rdivi(other, this);
}
/**
* Reverse division
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rdiv(INDArray other, INDArray result) {
validateNumericalArray("rdiv", false);
return dup().rdivi(other, result);
}
/**
* Reverse division (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rdivi(INDArray other, INDArray result) {
validateNumericalArray("rdivi", false);
@ -3761,23 +3619,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* Reverse subtraction
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rsub(INDArray other, INDArray result) {
validateNumericalArray("rsub", false);
return rsubi(other, result);
}
/**
* @param other
* @return
*/
@Override
public INDArray rsub(INDArray other) {
validateNumericalArray("rsub", false);
@ -3788,22 +3635,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* @param other
* @return
*/
@Override
public INDArray rsubi(INDArray other) {
return rsubi(other, this);
}
/**
* Reverse subtraction (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rsubi(INDArray other, INDArray result) {
validateNumericalArray("rsubi", false);
@ -3812,12 +3648,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result;
}
/**
* Set the value of the ndarray to the specified value
*
* @param value the value to assign
* @return the ndarray with the values
*/
@Override
public INDArray assign(Number value) {
Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " +
@ -3826,7 +3656,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
}
@Override
public INDArray assign(boolean value) {
return assign(value ? 1 : 0);
@ -3846,7 +3675,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Override
@Deprecated
@Deprecated //TODO: Investigate. Not deprecated in the base interface.
public long linearIndex(long i) {
long idx = i;
for (int j = 0; j < jvmShapeInfo.rank - 1; j++) {
@ -4073,15 +3902,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return addi(n, this);
}
/**
* Replicate and tile array to fill out to the given shape
* See:
* https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
* @param shape the new shape of this ndarray
* @return the shape to fill out to
*/
@Override
public INDArray repmat(int[] shape) {
Nd4j.getCompressor().autoDecompress(this);
@ -4109,16 +3929,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return out;
}
/**
* Insert a row in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param row the row insert into
* @param toPut the row to insert
* @return this
*/
@Override
public INDArray putRow(long row, INDArray toPut) {
if (isRowVector() && toPut.isVector()) {
@ -4127,15 +3937,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut);
}
/**
* Insert a column in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param column the column to insert
* @param toPut the array to put
* @return this
*/
@Override
public INDArray putColumn(int column, INDArray toPut) {
Nd4j.getCompressor().autoDecompress(this);
@ -4752,11 +4553,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return reshape(length());
}
/**
* Flattens the array for linear indexing
*
* @return the flattened version of this array
*/
@Override
public void sliceVectors(List<INDArray> list) {
if (isVector())
@ -4804,12 +4600,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return col.reshape(col.length(), 1);
}
/**
* Get whole rows from the passed indices.
*
* @param rindices
*/
@Override
public INDArray getRows(int[] rindices) {
Nd4j.getCompressor().autoDecompress(this);
@ -4826,13 +4616,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override
public INDArray get(INDArrayIndex... indexes) {
Nd4j.getCompressor().autoDecompress(this);
@ -5020,13 +4803,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return out;
}
/**
* Get whole columns
* from the passed indices.
*
* @param cindices
*/
@Override
public INDArray getColumns(int... cindices) {
if (!isMatrix() && !isVector())

View File

@ -153,7 +153,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
arrList.set(j,put);
}
}
}
}
else if(indices.isRowVector()) {
@ -161,12 +160,8 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
arrList.add(slice(indices.getInt(i)));
}
}
return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()]));
}
}
@Override
@ -259,21 +254,13 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
@Override
public INDArray get(List<List<Integer>> indices) {
return null;
}
@Override
public INDArray put(INDArray indices, INDArray element) {
INDArrayIndex[] realIndices = new INDArrayIndex[indices.rank()];
for(int i = 0; i < realIndices.length; i++) {
realIndices[i] = new SpecifiedIndex(indices.slice(i).dup().data().asInt());
}
return put(realIndices,element);
}
@Override

View File

@ -568,13 +568,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return this;
}
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override
public INDArray get(INDArrayIndex... indexes) {

View File

@ -124,14 +124,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return this;
}
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override
public INDArray get(INDArrayIndex... indexes) {
//check for row/column vector and point index being 0

View File

@ -458,7 +458,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray rsub(Number n);
/**
* Reverse subtraction in place - i.e., (n - thisArrayValues)
*
@ -467,7 +466,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray rsubi(Number n);
/**
* Division by a number
*
@ -484,7 +482,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray divi(Number n);
/**
* Scalar multiplication (copy)
*
@ -501,7 +498,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray muli(Number n);
/**
* Scalar subtraction (copied)
*
@ -510,7 +506,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray sub(Number n);
/**
* In place scalar subtraction
*
@ -535,7 +530,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray addi(Number n);
/**
* Reverse division (number / ndarray)
*
@ -545,7 +539,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray rdiv(Number n, INDArray result);
/**
* Reverse in place division
*
@ -573,11 +566,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray rsubi(Number n, INDArray result);
/**
* @param n
* @param result
* @return
* Division if ndarray by number
*
* @param n the number to divide by
* @param result the result ndarray
* @return the result ndarray
*/
INDArray div(Number n, INDArray result);
@ -586,24 +580,35 @@ public interface INDArray extends Serializable, AutoCloseable {
*
* @param n the number to divide by
* @param result the result ndarray
* @return
* @return the result ndarray
*/
INDArray divi(Number n, INDArray result);
/**
* Multiplication of ndarray.
*
* @param n the number to multiply by
* @param result the result ndarray
* @return the result ndarray
*/
INDArray mul(Number n, INDArray result);
/**
* In place multiplication of this ndarray
*
* @param n the number to divide by
* @param result the result ndarray
* @return
* @return the result ndarray
*/
INDArray muli(Number n, INDArray result);
/**
* Subtraction of this ndarray
*
* @param n the number to subtract by
* @param result the result ndarray
* @return the result ndarray
*/
INDArray sub(Number n, INDArray result);
/**
@ -615,6 +620,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray subi(Number n, INDArray result);
/**
* Addition of this ndarray.
* @param n the number to add
* @param result the result ndarray
* @return the result ndarray
*/
INDArray add(Number n, INDArray result);
/**
@ -626,25 +637,24 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray addi(Number n, INDArray result);
/**
* Returns a subset of this array based on the specified
* indexes
* Returns a subset of this array based on the specified indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
INDArray get(INDArrayIndex... indexes);
//TODO: revisit after #8166 is resolved.
/**
* Return a mask on whether each element
* matches the given condition
* Return a mask on whether each element matches the given condition
* @param comp
* @param condition
* @return
*/
INDArray match(INDArray comp,Condition condition);
//TODO: revisit after #8166 is resolved.
/**
* Returns a mask
* @param comp
@ -673,54 +683,51 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray getWhere(Number comp,Condition condition);
//TODO: unused / untested method. (only used to forward calls from putWhere(Number,INDArray ,Condition).
/**
* Assign the element according
* to the comparison array
* Assign the element according to the comparison array
* @param comp the comparison array
* @param put the elements to put
* @param condition the condition for masking on
* @return
* @return a copy of this array with the conditional assignments.
*/
INDArray putWhere(INDArray comp,INDArray put,Condition condition);
//TODO: unused / untested method.
/**
* Assign the element according
* to the comparison array
* Assign the element according to the comparison array
* @param comp the comparison array
* @param put the elements to put
* @param condition the condition for masking on
* @return
* @return a copy of this array with the conditional assignments.
*/
INDArray putWhere(Number comp,INDArray put,Condition condition);
//TODO: unused / untested method. (only used to forward calls from other putWhereWithMask implementations.
/**
* Use a pre computed mask
* for assigning arrays
* Use a pre computed mask for assigning arrays
* @param mask the mask to use
* @param put the array to put
* @return the resulting array
* @return a copy of this array with the conditional assignments.
*/
INDArray putWhereWithMask(INDArray mask,INDArray put);
//TODO: unused / untested method.
/**
* Use a pre computed mask
* for assigning arrays
* Use a pre computed mask for assigning arrays
* @param mask the mask to use
* @param put the array to put
* @return the resulting array
* @return a copy of this array with the conditional assignments.
*/
INDArray putWhereWithMask(INDArray mask,Number put);
//TODO: unused / untested method.
/**
* Assign the element according
* to the comparison array
* Assign the element according to the comparison array
* @param comp the comparison array
* @param put the elements to put
* @param condition the condition for masking on
* @return
* @return a copy of this array with the conditional assignments.
*/
INDArray putWhere(Number comp,Number put,Condition condition);
@ -731,14 +738,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray get(INDArray indices);
/**
* Get the elements from this ndarray based on the specified indices
* @param indices an ndaray of the indices to get the elements for
* @return the elements to get the array for
*/
@Deprecated
INDArray get(List<List<Integer>> indices);
/**
* Get an INDArray comprised of the specified columns only. Copy operation.
*
@ -771,20 +770,20 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray rdivi(INDArray other);
//TODO: unused / untested method.
/**
* Reverse division
*
* @param other the matrix to subtract from
* @param other the matrix to divide from
* @param result the result ndarray
* @return
* @return the result ndarray
*/
INDArray rdiv(INDArray other, INDArray result);
/**
* Reverse division (in-place)
*
* @param other the other ndarray to subtract
* @param other the matrix to divide from
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@ -795,11 +794,10 @@ public interface INDArray extends Serializable, AutoCloseable {
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
* @return the result ndarray
*/
INDArray rsub(INDArray other, INDArray result);
/**
* Element-wise reverse subtraction (copy op). i.e., other - this
*
@ -842,22 +840,19 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray assign(boolean value);
/**
* Get the linear index of the data in to
* the array
* Get the linear index of the data in to the array
*
* @param i the index to getScalar
* @return the linear index in to the data
*/
long linearIndex(long i);
//TODO: unused / untested method. only used recursively.
/**
*
* @param list
* Flattens the array for linear indexing in list.
*/
void sliceVectors(List<INDArray> list);
/**
* Assigns the given matrix (put) to the specified slice
*
@ -875,16 +870,15 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray cond(Condition condition);
/**
* Replicate and tile array to fill out to the given shape
*
* See:
* https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
* @param shape the new shape of this ndarray
* @return the shape to fill out to
*/
INDArray repmat(int... shape);
/**
* Repeat elements along a specified dimension.
*
@ -894,11 +888,9 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray repeat(int dimension, long... repeats);
/**
* Insert a row in to this array
* Will throw an exception if this
* ndarray is not a matrix
* Will throw an exception if this ndarray is not a matrix
*
* @param row the row insert into
* @param toPut the row to insert
@ -908,8 +900,7 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Insert a column in to this array
* Will throw an exception if this
* ndarray is not a matrix
* Will throw an exception if this ndarray is not a matrix
*
* @param column the column to insert
* @param toPut the array to put
@ -919,7 +910,6 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Returns the element at the specified row/column
* This will throw an exception if the
*
* @param row the row of the element to return
* @param column the row of the element to return
@ -950,26 +940,20 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
double distance1(INDArray other);
/**
* Put element in to the indices denoted by
* the indices ndarray.
* This is equivalent to:
* In numpy this is equivalent to:
* a[indices] = element
*
* in numpy.
*
* @param indices the indices to put
* @param element the element array to put
* @return this array
*/
INDArray put(INDArray indices,INDArray element);
/**
* Put the elements of the ndarray
* in to the specified indices
* Put the elements of the ndarray in to the specified indices
*
* @param indices the indices to put the ndarray in to
* @param element the ndarray to put
@ -978,8 +962,7 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray put(INDArrayIndex[] indices, INDArray element);
/**
* Put the elements of the ndarray
* in to the specified indices
* Put the elements of the ndarray in to the specified indices
*
* @param indices the indices to put the ndarray in to
* @param element the ndarray to put
@ -1007,7 +990,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray put(int i, int j, INDArray element);
/**
* Inserts the element at the specified index
*
@ -1018,7 +1000,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray put(int i, int j, Number element);
/**
* Inserts the element at the specified index
*
@ -1028,7 +1009,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray put(int i, INDArray element);
/**
* In place division of a column vector
*

View File

@ -100,20 +100,7 @@ public class IndexingTests extends BaseNd4jTest {
INDArray vals = Nd4j.valueArrayOf(new long[] {2,2,2,2},5, DataType.DOUBLE);
assertEquals(vals,x);
}
@Test @Ignore
public void testIndexGetDuplicate() {
List<List<Integer>> indices = new ArrayList<>();
indices.add(Arrays.asList(0,0));
INDArray linspace = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray get = linspace.get(indices);
INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
assertEquals(assertion,get);
}
@Test
public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);