INDArray refactoring (#170)
* javadoc Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove javaTensorAlongDimension Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * javadocmaster
parent
d31197db5f
commit
59a6e4e3ae
|
@ -1007,19 +1007,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return toTad;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the vector along a particular dimension
|
||||
*
|
||||
* @param index the index of the vector to getScalar
|
||||
* @param dimension the dimension to getScalar the vector from
|
||||
* @return the vector along a particular dimension
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public INDArray javaTensorAlongDimension(int index, int... dimension) {
|
||||
return doTad(index, dimension);
|
||||
}
|
||||
|
||||
private void setShapeInformation(Pair<DataBuffer, long[]> shapeInfo) {
|
||||
this.shapeInformation = shapeInfo.getFirst();
|
||||
this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond());
|
||||
|
@ -1110,14 +1097,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return ret2.permutei(finalPermuteDims);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Returns the number of possible vectors for a given dimension
|
||||
*
|
||||
* @param dimension the dimension to calculate the number of vectors for
|
||||
* @return the number of possible vectors along a dimension
|
||||
*/
|
||||
@Override
|
||||
public long vectorsAlongDimension(int dimension) {
|
||||
if (dimension == 0 && isVector() || isRowVectorOrScalar())
|
||||
|
@ -1150,17 +1129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return length / size(dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the vector along a particular dimension
|
||||
*
|
||||
* @param index the index of the vector to get
|
||||
* @param dimension the dimension to get the vector from
|
||||
* @return the vector along a particular dimension
|
||||
*/
|
||||
@Override
|
||||
public INDArray vectorAlongDimension(int index, int dimension) {
|
||||
if (dimension < 0)
|
||||
if (dimension < 0) {
|
||||
dimension = jvmShapeInfo.getRank() + dimension;
|
||||
}
|
||||
|
||||
//return the whole thing
|
||||
if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2
|
||||
|
@ -1168,12 +1141,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return this;
|
||||
}
|
||||
|
||||
INDArray ret = tensorAlongDimension(index, dimension);
|
||||
//if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector())
|
||||
// return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
|
||||
//else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector())
|
||||
// return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
|
||||
return ret;
|
||||
return tensorAlongDimension(index, dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1196,13 +1164,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Cumulative sum along a dimension
|
||||
*
|
||||
* @param dimension the dimension to perform cumulative sum along
|
||||
* @return the cumulative sum along the specified dimension
|
||||
*/
|
||||
@Override
|
||||
public INDArray cumsumi(int dimension) {
|
||||
validateNumericalArray("cumsumi", true);
|
||||
|
@ -1351,25 +1312,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cumulative sum along a dimension (in place)
|
||||
*
|
||||
* @param dimension the dimension to perform cumulative sum along
|
||||
* @return the cumulative sum along the specified dimension
|
||||
*/
|
||||
@Override
|
||||
public INDArray cumsum(int dimension) {
|
||||
validateNumericalArray("cumsum", true);
|
||||
return dup().cumsumi(dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
* Assign all of the elements in the given
|
||||
* ndarray to this ndarray
|
||||
*
|
||||
* @param arr the elements to assign
|
||||
* @return this
|
||||
*/
|
||||
@Override
|
||||
public INDArray assign(final INDArray arr) {
|
||||
Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()),
|
||||
|
@ -1378,7 +1326,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
|
||||
Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal");
|
||||
|
||||
//Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set(this, arr, this, length()));
|
||||
Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this));
|
||||
return this;
|
||||
}
|
||||
|
@ -1413,7 +1360,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
@Override
|
||||
public INDArray putScalar(long i, float value) {
|
||||
return putScalar(i, (double) value);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1540,7 +1486,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return this;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray putScalar(int[] indexes, float value) {
|
||||
return putScalar(indexes, (double) value);
|
||||
|
@ -1556,27 +1501,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return putScalar(indexes, (double) value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an ndarray with 1 if the element is epsilon equals
|
||||
*
|
||||
* @param other the number to compare
|
||||
* @return a copied ndarray with the given
|
||||
* binary conditions
|
||||
*/
|
||||
@Override
|
||||
public INDArray eps(Number other) {
|
||||
validateNumericalArray("eps", true);
|
||||
return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
|
||||
}
|
||||
|
||||
/**
|
||||
* epsilon equals than comparison:
|
||||
* If the given number is less than the
|
||||
* comparison number the item is 0 otherwise 1
|
||||
*
|
||||
* @param other the number to compare
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray eps(INDArray other) {
|
||||
validateNumericalArray("eps", true);
|
||||
|
@ -1613,7 +1543,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray lt(INDArray other) {
|
||||
validateNumericalArray("less than (lt)", false);
|
||||
|
@ -1675,9 +1604,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Negate each element.
|
||||
*/
|
||||
@Override
|
||||
public INDArray neg() {
|
||||
validateNumericalArray("negative (neg)", true);
|
||||
|
@ -1686,9 +1612,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())));
|
||||
}
|
||||
|
||||
/**
|
||||
* Negate each element (in-place).
|
||||
*/
|
||||
@Override
|
||||
public INDArray negi() {
|
||||
validateNumericalArray("negative (negi)", true);
|
||||
|
@ -3909,28 +3832,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return assign(value ? 1 : 0);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Assign all elements from given ndarray that are matching given condition,
|
||||
* ndarray to this ndarray
|
||||
*
|
||||
* @param arr the elements to assign
|
||||
* @param condition
|
||||
* @return this
|
||||
*/
|
||||
@Override
|
||||
public INDArray assignIf(INDArray arr, Condition condition) {
|
||||
BooleanIndexing.assignIf(this, arr, condition);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces all elements in this ndarray that are matching give condition, with corresponding elements from given array
|
||||
*
|
||||
* @param arr
|
||||
* @param condition
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray replaceWhere(INDArray arr, Condition condition) {
|
||||
Nd4j.getCompressor().autoDecompress(this);
|
||||
|
|
|
@ -411,11 +411,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray javaTensorAlongDimension(int index, int... dimension) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray cumsumi(int dimension) {
|
||||
return null;
|
||||
|
@ -476,7 +471,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
return null;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray isInfinite() {
|
||||
throw new UnsupportedOperationException();
|
||||
|
@ -551,6 +545,7 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
public INDArray lte(Number other) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray lt(INDArray other) {
|
||||
return null;
|
||||
|
|
|
@ -106,27 +106,31 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
|
||||
/**
|
||||
* Element wise stride
|
||||
* @return the element wise stride
|
||||
*/
|
||||
int elementWiseStride();
|
||||
|
||||
/**
|
||||
* Get a scalar
|
||||
* at the given linear offset
|
||||
* Get a double at the given linear offset unsafe, without checks.
|
||||
* @param offset the offset to get at
|
||||
* @return this
|
||||
* @return double value at offset
|
||||
*/
|
||||
double getDoubleUnsafe(long offset);
|
||||
double getDoubleUnsafe(long offset); //TODO: consider deleting.
|
||||
|
||||
/**
|
||||
* Get string value at given index.
|
||||
* @param index index to retreive
|
||||
* @return string value at index.
|
||||
*/
|
||||
String getString(long index);
|
||||
|
||||
/**
|
||||
* Insert a scalar
|
||||
* at the given linear offset
|
||||
* Insert a scalar at the given linear offset
|
||||
* @param offset the offset to insert at
|
||||
* @param value the value to insert
|
||||
* @return this
|
||||
*/
|
||||
INDArray putScalarUnsafe(long offset, double value);
|
||||
INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting.
|
||||
|
||||
/**
|
||||
* Returns the number of possible vectors for a given dimension
|
||||
|
@ -162,17 +166,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray tensorAlongDimension(long index, int... dimension);
|
||||
|
||||
/**
|
||||
* Get the vector along a particular dimension
|
||||
*
|
||||
* @param index the index of the vector to getScalar
|
||||
* @param dimension the dimension to getScalar the vector from
|
||||
* @return the vector along a particular dimension
|
||||
*/
|
||||
@Deprecated
|
||||
INDArray javaTensorAlongDimension(int index, int... dimension);
|
||||
|
||||
|
||||
/**
|
||||
* Returns the cumulative sum along a dimension. In-place method.
|
||||
*
|
||||
|
@ -190,8 +183,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
INDArray cumsum(int dimension);
|
||||
|
||||
/**
|
||||
* Assign all of the elements in the given
|
||||
* ndarray to this ndarray
|
||||
* Assign all of the elements in the given ndarray to this ndarray
|
||||
*
|
||||
* @param arr the elements to assign
|
||||
* @return this
|
||||
|
@ -254,9 +246,19 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray putScalar(int[] i, double value);
|
||||
|
||||
|
||||
/**
|
||||
* See {@link #putScalar(int[], double)}
|
||||
*/
|
||||
INDArray putScalar(long[] i, double value);
|
||||
|
||||
/**
|
||||
* See {@link #putScalar(int[], double)}
|
||||
*/
|
||||
INDArray putScalar(long[] i, float value);
|
||||
|
||||
/**
|
||||
* See {@link #putScalar(int[], double)}
|
||||
*/
|
||||
INDArray putScalar(long[] i, int value);
|
||||
|
||||
/**
|
||||
|
@ -300,7 +302,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray lt(Number other);
|
||||
|
||||
|
||||
/**
|
||||
* Put the specified float value at the specified indices in this array
|
||||
*
|
||||
|
@ -327,8 +328,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray eps(Number other);
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Returns the binary ndarray for "Equals" comparison.
|
||||
*
|
||||
|
@ -367,10 +366,8 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
* @param other the ndarray to compare.
|
||||
* @return the binary ndarray for "Less" comparison.
|
||||
*/
|
||||
|
||||
INDArray lt(INDArray other);
|
||||
|
||||
|
||||
/**
|
||||
* Returns the binary ndarray for "Epsilon equals" comparison.
|
||||
*
|
||||
|
@ -403,7 +400,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray eq(INDArray other);
|
||||
|
||||
|
||||
/**
|
||||
* Returns the binary ndarray for "Greater Than" comparison.
|
||||
*
|
||||
|
@ -424,8 +420,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
INDArray isNaN();
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Returns the ndarray negative (cloned)
|
||||
*
|
||||
|
|
|
@ -181,37 +181,37 @@ public class NDArrayCreationUtil {
|
|||
INDArray[] out = new INDArray[12];
|
||||
|
||||
INDArray temp01 = Nd4j.linspace(1, cols * rows * 4, cols * rows * 4, dataType).reshape(cols, rows, 4);
|
||||
out[0] = temp01.javaTensorAlongDimension(0, 0, 1).reshape(rows, cols);
|
||||
out[0] = temp01.tensorAlongDimension(0, 0, 1).reshape(rows, cols);
|
||||
long[] temp01Shape = new long[] {cols, rows, 4};
|
||||
int len = ArrayUtil.prod(temp01Shape);
|
||||
temp01 = Nd4j.linspace(1, len, len, dataType).reshape(temp01Shape);
|
||||
out[1] = temp01.javaTensorAlongDimension(2, 0, 1).reshape(rows, cols);
|
||||
out[1] = temp01.tensorAlongDimension(2, 0, 1).reshape(rows, cols);
|
||||
|
||||
Nd4j.getRandom().setSeed(seed);
|
||||
INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(new long[] {cols, 4, rows});
|
||||
out[2] = temp02.javaTensorAlongDimension(0, 0, 2).reshape(rows, cols);
|
||||
out[2] = temp02.tensorAlongDimension(0, 0, 2).reshape(rows, cols);
|
||||
temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows);
|
||||
out[3] = temp02.javaTensorAlongDimension(2, 0, 2).reshape(rows, cols);
|
||||
out[3] = temp02.tensorAlongDimension(2, 0, 2).reshape(rows, cols);
|
||||
|
||||
INDArray temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4);
|
||||
out[4] = temp10.javaTensorAlongDimension(0, 1, 0).reshape(rows, cols);
|
||||
out[4] = temp10.tensorAlongDimension(0, 1, 0).reshape(rows, cols);
|
||||
temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4);
|
||||
out[5] = temp10.javaTensorAlongDimension(2, 1, 0).reshape(rows, cols);
|
||||
out[5] = temp10.tensorAlongDimension(2, 1, 0).reshape(rows, cols);
|
||||
|
||||
INDArray temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows);
|
||||
out[6] = temp12.javaTensorAlongDimension(0, 1, 2).reshape(rows, cols);
|
||||
out[6] = temp12.tensorAlongDimension(0, 1, 2).reshape(rows, cols);
|
||||
temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows);
|
||||
out[7] = temp12.javaTensorAlongDimension(2, 1, 2).reshape(rows, cols);
|
||||
out[7] = temp12.tensorAlongDimension(2, 1, 2).reshape(rows, cols);
|
||||
|
||||
INDArray temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols);
|
||||
out[8] = temp20.javaTensorAlongDimension(0, 2, 0).reshape(rows, cols);
|
||||
out[8] = temp20.tensorAlongDimension(0, 2, 0).reshape(rows, cols);
|
||||
temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols);
|
||||
out[9] = temp20.javaTensorAlongDimension(2, 2, 0).reshape(rows, cols);
|
||||
out[9] = temp20.tensorAlongDimension(2, 2, 0).reshape(rows, cols);
|
||||
|
||||
INDArray temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols);
|
||||
out[10] = temp21.javaTensorAlongDimension(0, 2, 1).reshape(rows, cols);
|
||||
out[10] = temp21.tensorAlongDimension(0, 2, 1).reshape(rows, cols);
|
||||
temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols);
|
||||
out[11] = temp21.javaTensorAlongDimension(2, 2, 1).reshape(rows, cols);
|
||||
out[11] = temp21.tensorAlongDimension(2, 2, 1).reshape(rows, cols);
|
||||
|
||||
String baseMsg = "getTensorAlongDimensionMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
|
||||
List<Pair<INDArray, String>> list = new ArrayList<>(12);
|
||||
|
@ -361,9 +361,9 @@ public class NDArrayCreationUtil {
|
|||
val shape4d1 = new long[]{shape[0], shape[1], shape[2], 3};
|
||||
int lenshape4d1 = ArrayUtil.prod(shape4d1);
|
||||
INDArray orig1a = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1);
|
||||
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 0, 1, 2);
|
||||
INDArray tad1a = orig1a.tensorAlongDimension(0, 0, 1, 2);
|
||||
INDArray orig1b = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1);
|
||||
INDArray tad1b = orig1b.javaTensorAlongDimension(1, 0, 1, 2);
|
||||
INDArray tad1b = orig1b.tensorAlongDimension(1, 0, 1, 2);
|
||||
|
||||
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
|
||||
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
|
||||
|
@ -371,19 +371,19 @@ public class NDArrayCreationUtil {
|
|||
long[] shape4d2 = {3, shape[0], shape[1], shape[2]};
|
||||
int lenshape4d2 = ArrayUtil.prod(shape4d2);
|
||||
INDArray orig2 = Nd4j.linspace(1, lenshape4d2, lenshape4d2, dataType).reshape(shape4d2);
|
||||
INDArray tad2 = orig2.javaTensorAlongDimension(1, 1, 2, 3);
|
||||
INDArray tad2 = orig2.tensorAlongDimension(1, 1, 2, 3);
|
||||
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
|
||||
|
||||
long[] shape4d3 = {shape[0], shape[1], 3, shape[2]};
|
||||
int lenshape4d3 = ArrayUtil.prod(shape4d3);
|
||||
INDArray orig3 = Nd4j.linspace(1, lenshape4d3, lenshape4d3, dataType).reshape(shape4d3);
|
||||
INDArray tad3 = orig3.javaTensorAlongDimension(1, 1, 3, 0);
|
||||
INDArray tad3 = orig3.tensorAlongDimension(1, 1, 3, 0);
|
||||
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
|
||||
|
||||
long[] shape4d4 = {shape[0], 3, shape[1], shape[2]};
|
||||
int lenshape4d4 = ArrayUtil.prod(shape4d4);
|
||||
INDArray orig4 = Nd4j.linspace(1, lenshape4d4, lenshape4d4, dataType).reshape(shape4d4);
|
||||
INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3);
|
||||
INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3);
|
||||
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
|
||||
|
||||
return list;
|
||||
|
@ -513,9 +513,9 @@ public class NDArrayCreationUtil {
|
|||
int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3]};
|
||||
int len = ArrayUtil.prod(shape4d1);
|
||||
INDArray orig1a = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1));
|
||||
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4);
|
||||
INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4);
|
||||
INDArray orig1b = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1));
|
||||
INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4);
|
||||
INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4);
|
||||
|
||||
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
|
||||
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
|
||||
|
@ -523,19 +523,19 @@ public class NDArrayCreationUtil {
|
|||
int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3]};
|
||||
int len2 = ArrayUtil.prod(shape4d2);
|
||||
INDArray orig2 = Nd4j.linspace(1, len2, len2, dataType).reshape(ArrayUtil.toLongArray(shape4d2));
|
||||
INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 4, 2, 1);
|
||||
INDArray tad2 = orig2.tensorAlongDimension(1, 3, 4, 2, 1);
|
||||
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
|
||||
|
||||
int[] shape4d3 = {shape[0], shape[1], 3, shape[2], shape[3]};
|
||||
int len3 = ArrayUtil.prod(shape4d3);
|
||||
INDArray orig3 = Nd4j.linspace(1, len3, len3, dataType).reshape(ArrayUtil.toLongArray(shape4d3));
|
||||
INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 0);
|
||||
INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 0);
|
||||
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
|
||||
|
||||
int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3};
|
||||
int len4 = ArrayUtil.prod(shape4d4);
|
||||
INDArray orig4 = Nd4j.linspace(1, len4, len4, dataType).reshape(ArrayUtil.toLongArray(shape4d4));
|
||||
INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3, 1);
|
||||
INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3, 1);
|
||||
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
|
||||
|
||||
return list;
|
||||
|
@ -655,26 +655,26 @@ public class NDArrayCreationUtil {
|
|||
Nd4j.getRandom().setSeed(seed);
|
||||
int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]};
|
||||
INDArray orig1a = Nd4j.rand(dataType, shape4d1);
|
||||
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4, 5);
|
||||
INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4, 5);
|
||||
INDArray orig1b = Nd4j.rand(dataType, shape4d1);
|
||||
INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4, 5);
|
||||
INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4, 5);
|
||||
|
||||
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
|
||||
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
|
||||
|
||||
int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]};
|
||||
INDArray orig2 = Nd4j.rand(dataType, shape4d2);
|
||||
INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 5, 4, 2, 1);
|
||||
INDArray tad2 = orig2.tensorAlongDimension(1, 3, 5, 4, 2, 1);
|
||||
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
|
||||
|
||||
int[] shape4d3 = {shape[0], shape[1], shape[2], shape[3], shape[4], 2};
|
||||
INDArray orig3 = Nd4j.rand(dataType, shape4d3);
|
||||
INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 2, 0);
|
||||
INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 2, 0);
|
||||
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
|
||||
|
||||
int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3, shape[4]};
|
||||
INDArray orig4 = Nd4j.rand(dataType, shape4d4);
|
||||
INDArray tad4 = orig4.javaTensorAlongDimension(1, 5, 2, 0, 3, 1);
|
||||
INDArray tad4 = orig4.tensorAlongDimension(1, 5, 2, 0, 3, 1);
|
||||
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
|
||||
|
||||
return list;
|
||||
|
|
|
@ -374,7 +374,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
|
|||
long nTensors = labels.tensorsAlongDimension(1);
|
||||
for (int i = 0; i < nTensors; i++) {
|
||||
INDArray row = labels.tensorAlongDimension(i, 1);
|
||||
INDArray javaRow = labels.javaTensorAlongDimension(i, 1);
|
||||
INDArray javaRow = labels.tensorAlongDimension(i, 1);
|
||||
int maxIdx = Nd4j.getBlasWrapper().iamax(row);
|
||||
int maxIdxJava = Nd4j.getBlasWrapper().iamax(javaRow);
|
||||
if (maxIdx < 0)
|
||||
|
|
|
@ -1334,7 +1334,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
INDArray zC = Nd4j.create(shape, 'c');
|
||||
zC.setData(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data());
|
||||
for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) {
|
||||
INDArray javaTad = zC.javaTensorAlongDimension(tad, dim);
|
||||
INDArray javaTad = zC.tensorAlongDimension(tad, dim);
|
||||
System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim));
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < n; i++) {
|
||||
StopWatch javaTiming = new StopWatch();
|
||||
javaTiming.start();
|
||||
row.javaTensorAlongDimension(0, 0);
|
||||
row.tensorAlongDimension(0, 0);
|
||||
javaTiming.stop();
|
||||
StopWatch cTiming = new StopWatch();
|
||||
cTiming.start();
|
||||
|
@ -98,7 +98,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
assertEquals(cols, arr.tensorsAlongDimension(0));
|
||||
for (int i = 0; i < cols; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0);
|
||||
INDArray javaTad = arr.javaTensorAlongDimension(i, 0);
|
||||
INDArray javaTad = arr.tensorAlongDimension(i, 0);
|
||||
assertEquals(javaTad, tad);
|
||||
assertArrayEquals(new int[] {rows}, tad.shape());
|
||||
//assertEquals(testValues.javaTensorAlongDimension(i, 0), tad);
|
||||
|
@ -120,7 +120,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
list = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, new int[]{rows, cols, dim2}, DataType.DOUBLE);
|
||||
for (Pair<INDArray, String> p : list) {
|
||||
INDArray arr = p.getFirst().assign(testValues);
|
||||
INDArray javaTad = arr.javaTensorAlongDimension(0, 0);
|
||||
INDArray javaTad = arr.tensorAlongDimension(0, 0);
|
||||
INDArray tadTest = arr.tensorAlongDimension(0, 0);
|
||||
assertEquals(javaTad, tadTest);
|
||||
//Along dimension 0: expect row vector with length 'rows'
|
||||
|
@ -165,7 +165,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
//Along dimension 0,1: expect matrix with shape [rows,cols]
|
||||
assertEquals(dim2, arr.tensorsAlongDimension(0, 1));
|
||||
for (int i = 0; i < dim2; i++) {
|
||||
INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 1);
|
||||
INDArray javaTad = arr.tensorAlongDimension(i, 0, 1);
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0, 1);
|
||||
int javaEleStride = javaTad.elementWiseStride();
|
||||
int testTad = tad.elementWiseStride();
|
||||
|
@ -178,11 +178,11 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
//Along dimension 0,2: expect matrix with shape [rows,dim2]
|
||||
assertEquals(cols, arr.tensorsAlongDimension(0, 2));
|
||||
for (int i = 0; i < cols; i++) {
|
||||
INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 2);
|
||||
INDArray javaTad = arr.tensorAlongDimension(i, 0, 2);
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0, 2);
|
||||
assertEquals(javaTad, tad);
|
||||
assertArrayEquals(new long[] {rows, dim2}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad);
|
||||
}
|
||||
|
||||
//Along dimension 1,2: expect matrix with shape [cols,dim2]
|
||||
|
@ -190,7 +190,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < rows; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 1, 2);
|
||||
assertArrayEquals(new long[] {cols, dim2}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -207,7 +207,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < dim2 * dim3; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0, 1);
|
||||
assertArrayEquals(new long[] {rows, cols}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 0, 1), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 0, 1), tad);
|
||||
}
|
||||
|
||||
//Along dimension 0,2: expect matrix with shape [rows,dim2]
|
||||
|
@ -215,7 +215,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < cols * dim3; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0, 2);
|
||||
assertArrayEquals(new long[] {rows, dim2}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad);
|
||||
}
|
||||
|
||||
//Along dimension 0,3: expect matrix with shape [rows,dim3]
|
||||
|
@ -223,7 +223,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < cols * dim2; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 0, 3);
|
||||
assertArrayEquals(new long[] {rows, dim3}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 0, 3), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 0, 3), tad);
|
||||
}
|
||||
|
||||
|
||||
|
@ -232,7 +232,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < rows * dim3; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 1, 2);
|
||||
assertArrayEquals(new long[] {cols, dim2}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad);
|
||||
}
|
||||
|
||||
//Along dimension 1,3: expect matrix with shape [cols,dim3]
|
||||
|
@ -240,7 +240,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < rows * dim2; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 1, 3);
|
||||
assertArrayEquals(new long[] {cols, dim3}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 1, 3), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 1, 3), tad);
|
||||
}
|
||||
|
||||
//Along dimension 2,3: expect matrix with shape [dim2,dim3]
|
||||
|
@ -248,7 +248,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
|
|||
for (int i = 0; i < rows * cols; i++) {
|
||||
INDArray tad = arr.tensorAlongDimension(i, 2, 3);
|
||||
assertArrayEquals(new long[] {dim2, dim3}, tad.shape());
|
||||
assertEquals(testValues.javaTensorAlongDimension(i, 2, 3), tad);
|
||||
assertEquals(testValues.tensorAlongDimension(i, 2, 3), tad);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -614,7 +614,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
|
|||
assertEquals(arr5s.getDouble(i), 16, 1e-1);
|
||||
System.out.println("6d");
|
||||
INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4);
|
||||
INDArray arr6Tad = arr6.javaTensorAlongDimension(0, 2, 3);
|
||||
INDArray arr6Tad = arr6.tensorAlongDimension(0, 2, 3);
|
||||
INDArray arr6s = arr6.sum(2, 3);
|
||||
for (int i = 0; i < arr6s.length(); i++)
|
||||
assertEquals(arr6s.getDouble(i), 16, 1e-1);
|
||||
|
|
|
@ -79,7 +79,7 @@ public class TADTests extends BaseNd4jTest {
|
|||
|
||||
int[] shape = new int[] {e, x};
|
||||
Arrays.sort(shape);
|
||||
INDArray assertion = array.javaTensorAlongDimension(0, shape);
|
||||
INDArray assertion = array.tensorAlongDimension(0, shape);
|
||||
INDArray test = array.tensorAlongDimension(0, shape);
|
||||
|
||||
assertEquals(assertion, test);
|
||||
|
@ -101,7 +101,7 @@ public class TADTests extends BaseNd4jTest {
|
|||
Arrays.sort(shape);
|
||||
log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape "
|
||||
+ array.shapeInfoToString());
|
||||
INDArray assertion = array.javaTensorAlongDimension(0, shape);
|
||||
INDArray assertion = array.tensorAlongDimension(0, shape);
|
||||
INDArray test = array.tensorAlongDimension(0, shape);
|
||||
assertEquals(assertion, test);
|
||||
//assertEquals(assertion.shapeInfoDataBuffer(), test.shapeInfoDataBuffer());
|
||||
|
@ -121,8 +121,8 @@ public class TADTests extends BaseNd4jTest {
|
|||
public void testMysteriousCrash() {
|
||||
INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f');
|
||||
INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c');
|
||||
INDArray javaCTad = arrayC.javaTensorAlongDimension(0, 2, 3);
|
||||
INDArray javaFTad = arrayF.javaTensorAlongDimension(0, 2, 3);
|
||||
INDArray javaCTad = arrayC.tensorAlongDimension(0, 2, 3);
|
||||
INDArray javaFTad = arrayF.tensorAlongDimension(0, 2, 3);
|
||||
Pair<DataBuffer, DataBuffer> tadBuffersF =
|
||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3);
|
||||
Pair<DataBuffer, DataBuffer> tadBuffersC =
|
||||
|
|
|
@ -185,7 +185,7 @@ public class ConcatTestsC extends BaseNd4jTest {
|
|||
//ConcatV2, dim 1
|
||||
second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4);
|
||||
for (int i = 0; i < second.tensorsAlongDimension(1); i++) {
|
||||
INDArray secondTad = second.javaTensorAlongDimension(i, 1);
|
||||
INDArray secondTad = second.tensorAlongDimension(i, 1);
|
||||
System.out.println(second.tensorAlongDimension(i, 1));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue