INDArray refactoring (#170)

* javadoc

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove javaTensorAlongDimension

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* javadoc

Branch: master
Author: Robert Altena, 2019-08-28 12:03:23 +09:00 (committed by GitHub)
Parent: d31197db5f
Commit: 59a6e4e3ae
10 changed files with 75 additions and 179 deletions
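
The bulk of the change removes the deprecated javaTensorAlongDimension in favour of the existing tensorAlongDimension, and moves or cleans up Javadoc. A minimal migration sketch, with array shape and values chosen only for illustration and factory calls mirroring the ones used in the hunks below:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TadMigrationSketch {
    public static void main(String[] args) {
        // 2 x 3 x 4 array holding 1..24
        INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(2, 3, 4);

        // Before: arr.javaTensorAlongDimension(0, 1, 2)  (deprecated, removed by this commit)
        // After:  the native-backed method is the single entry point
        INDArray tad = arr.tensorAlongDimension(0, 1, 2);   // first tensor along dims (1, 2)

        System.out.println(java.util.Arrays.toString(tad.shape()));   // [3, 4]
    }
}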


@ -1007,19 +1007,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return toTad;
}
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to getScalar
* @param dimension the dimension to getScalar the vector from
* @return the vector along a particular dimension
*/
@Override
@Deprecated
public INDArray javaTensorAlongDimension(int index, int... dimension) {
return doTad(index, dimension);
}
private void setShapeInformation(Pair<DataBuffer, long[]> shapeInfo) {
this.shapeInformation = shapeInfo.getFirst();
this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond());
@ -1110,14 +1097,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret2.permutei(finalPermuteDims);
}
/**
* Returns the number of possible vectors for a given dimension
*
* @param dimension the dimension to calculate the number of vectors for
* @return the number of possible vectors along a dimension
*/
@Override
public long vectorsAlongDimension(int dimension) {
if (dimension == 0 && isVector() || isRowVectorOrScalar())
@ -1150,17 +1129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return length / size(dimension);
}
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to get
* @param dimension the dimension to get the vector from
* @return the vector along a particular dimension
*/
@Override
public INDArray vectorAlongDimension(int index, int dimension) {
if (dimension < 0)
if (dimension < 0) {
dimension = jvmShapeInfo.getRank() + dimension;
}
//return the whole thing
if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2
@ -1168,12 +1141,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
}
INDArray ret = tensorAlongDimension(index, dimension);
//if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector())
// return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
//else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector())
// return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
return ret;
return tensorAlongDimension(index, dimension);
}
@Override
@ -1196,13 +1164,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
}
/**
* Cumulative sum along a dimension
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsumi(int dimension) {
validateNumericalArray("cumsumi", true);
@ -1351,25 +1312,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return logEntropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Cumulative sum along a dimension (in place)
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsum(int dimension) {
validateNumericalArray("cumsum", true);
return dup().cumsumi(dimension);
}
/**
* Assign all of the elements in the given
* ndarray to this ndarray
*
* @param arr the elements to assign
* @return this
*/
@Override
public INDArray assign(final INDArray arr) {
Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()),
@ -1378,7 +1326,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal");
//Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set(this, arr, this, length()));
Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this));
return this;
}
@ -1413,7 +1360,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray putScalar(long i, float value) {
return putScalar(i, (double) value);
}
@Override
@ -1540,7 +1486,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
}
@Override
public INDArray putScalar(int[] indexes, float value) {
return putScalar(indexes, (double) value);
@ -1556,27 +1501,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return putScalar(indexes, (double) value);
}
/**
* Returns an ndarray with 1 if the element is epsilon equals
*
* @param other the number to compare
* @return a copied ndarray with the given
* binary conditions
*/
@Override
public INDArray eps(Number other) {
validateNumericalArray("eps", true);
return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
}
/**
* epsilon equals than comparison:
* If the given number is less than the
* comparison number the item is 0 otherwise 1
*
* @param other the number to compare
* @return
*/
@Override
public INDArray eps(INDArray other) {
validateNumericalArray("eps", true);
@ -1613,7 +1543,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
}
@Override
public INDArray lt(INDArray other) {
validateNumericalArray("less than (lt)", false);
@ -1675,9 +1604,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan()));
}
/**
* Negate each element.
*/
@Override
public INDArray neg() {
validateNumericalArray("negative (neg)", true);
@ -1686,9 +1612,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())));
}
/**
* Negate each element (in-place).
*/
@Override
public INDArray negi() {
validateNumericalArray("negative (negi)", true);
@ -3909,28 +3832,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return assign(value ? 1 : 0);
}
/**
* Assign all elements from given ndarray that are matching given condition,
* ndarray to this ndarray
*
* @param arr the elements to assign
* @param condition
* @return this
*/
@Override
public INDArray assignIf(INDArray arr, Condition condition) {
BooleanIndexing.assignIf(this, arr, condition);
return this;
}
/**
* Replaces all elements in this ndarray that are matching give condition, with corresponding elements from given array
*
* @param arr
* @param condition
* @return
*/
@Override
public INDArray replaceWhere(INDArray arr, Condition condition) {
Nd4j.getCompressor().autoDecompress(this);
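
For the BaseNDArray methods whose Javadoc is consolidated above (cumsum/cumsumi, assign, neg/negi, assignIf/replaceWhere), a short usage sketch; the Conditions helper is assumed from org.nd4j.linalg.indexing.conditions and all values are illustrative only:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class BaseNdArrayOpsSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);

        // cumsum returns a copy; cumsumi accumulates in place
        INDArray cs = a.cumsum(1);                 // [[1,3,6],[4,9,15]]

        // neg returns a negated copy; negi negates in place
        INDArray negated = a.neg();

        // replaceWhere: elements of the receiver matching the condition are
        // overwritten with the corresponding elements of the argument
        INDArray mixed = a.dup().replaceWhere(negated, Conditions.greaterThan(3));

        // assign copies every element of a shape-compatible array into the receiver
        INDArray b = a.dup().assign(cs);

        System.out.println(cs);
        System.out.println(mixed);
        System.out.println(b);
    }
}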


@ -411,11 +411,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
@Override
public INDArray javaTensorAlongDimension(int index, int... dimension) {
return null;
}
@Override
public INDArray cumsumi(int dimension) {
return null;
@ -476,7 +471,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
@Override
public INDArray isInfinite() {
throw new UnsupportedOperationException();
@ -551,6 +545,7 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
public INDArray lte(Number other) {
return null;
}
@Override
public INDArray lt(INDArray other) {
return null;


@ -106,27 +106,31 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Element wise stride
* @return the element wise stride
*/
int elementWiseStride();
/**
* Get a scalar
* at the given linear offset
* Get a double at the given linear offset (unsafe: no checks are performed).
* @param offset the offset to get at
* @return this
* @return double value at offset
*/
double getDoubleUnsafe(long offset);
double getDoubleUnsafe(long offset); //TODO: consider deleting.
/**
* Get string value at given index.
* @param index the index to retrieve
* @return string value at index.
*/
String getString(long index);
/**
* Insert a scalar
* at the given linear offset
* Insert a scalar at the given linear offset
* @param offset the offset to insert at
* @param value the value to insert
* @return this
*/
INDArray putScalarUnsafe(long offset, double value);
INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting.
/**
* Returns the number of possible vectors for a given dimension
@ -162,17 +166,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray tensorAlongDimension(long index, int... dimension);
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to getScalar
* @param dimension the dimension to getScalar the vector from
* @return the vector along a particular dimension
*/
@Deprecated
INDArray javaTensorAlongDimension(int index, int... dimension);
/**
* Returns the cumulative sum along a dimension. In-place method.
*
@ -190,8 +183,7 @@ public interface INDArray extends Serializable, AutoCloseable {
INDArray cumsum(int dimension);
/**
* Assign all of the elements in the given
* ndarray to this ndarray
* Assign all of the elements in the given ndarray to this ndarray
*
* @param arr the elements to assign
* @return this
@ -254,9 +246,19 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray putScalar(int[] i, double value);
/**
* See {@link #putScalar(int[], double)}
*/
INDArray putScalar(long[] i, double value);
/**
* See {@link #putScalar(int[], double)}
*/
INDArray putScalar(long[] i, float value);
/**
* See {@link #putScalar(int[], double)}
*/
INDArray putScalar(long[] i, int value);
/**
@ -300,7 +302,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray lt(Number other);
/**
* Put the specified float value at the specified indices in this array
*
@ -327,8 +328,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray eps(Number other);
/**
* Returns the binary ndarray for "Equals" comparison.
*
@ -367,10 +366,8 @@ public interface INDArray extends Serializable, AutoCloseable {
* @param other the ndarray to compare.
* @return the binary ndarray for "Less" comparison.
*/
INDArray lt(INDArray other);
/**
* Returns the binary ndarray for "Epsilon equals" comparison.
*
@ -403,7 +400,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray eq(INDArray other);
/**
* Returns the binary ndarray for "Greater Than" comparison.
*
@ -424,8 +420,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
INDArray isNaN();
/**
* Returns the ndarray negative (cloned)
*
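
The new Javadoc above points the long[]/float/int putScalar overloads at putScalar(int[], double). A small sketch of that round trip together with the element-wise comparison methods (eps, lt); the boolean-typed results are assumed from the DataType.BOOL buffers created in the BaseNDArray implementations above:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class IndArrayInterfaceSketch {
    public static void main(String[] args) {
        INDArray m = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);

        // the long[]-indexed putScalar is documented via putScalar(int[], double)
        m.putScalar(new long[] {0, 1}, 10.0);
        System.out.println(m.getDouble(0, 1));    // 10.0

        // element-wise comparisons return boolean masks
        INDArray nearTen    = m.eps(10.0);        // "epsilon equals" against a scalar
        INDArray belowThree = m.lt(3.0);          // "less than" against a scalar

        System.out.println(nearTen);
        System.out.println(belowThree);
    }
}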


@ -181,37 +181,37 @@ public class NDArrayCreationUtil {
INDArray[] out = new INDArray[12];
INDArray temp01 = Nd4j.linspace(1, cols * rows * 4, cols * rows * 4, dataType).reshape(cols, rows, 4);
out[0] = temp01.javaTensorAlongDimension(0, 0, 1).reshape(rows, cols);
out[0] = temp01.tensorAlongDimension(0, 0, 1).reshape(rows, cols);
long[] temp01Shape = new long[] {cols, rows, 4};
int len = ArrayUtil.prod(temp01Shape);
temp01 = Nd4j.linspace(1, len, len, dataType).reshape(temp01Shape);
out[1] = temp01.javaTensorAlongDimension(2, 0, 1).reshape(rows, cols);
out[1] = temp01.tensorAlongDimension(2, 0, 1).reshape(rows, cols);
Nd4j.getRandom().setSeed(seed);
INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(new long[] {cols, 4, rows});
out[2] = temp02.javaTensorAlongDimension(0, 0, 2).reshape(rows, cols);
out[2] = temp02.tensorAlongDimension(0, 0, 2).reshape(rows, cols);
temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows);
out[3] = temp02.javaTensorAlongDimension(2, 0, 2).reshape(rows, cols);
out[3] = temp02.tensorAlongDimension(2, 0, 2).reshape(rows, cols);
INDArray temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4);
out[4] = temp10.javaTensorAlongDimension(0, 1, 0).reshape(rows, cols);
out[4] = temp10.tensorAlongDimension(0, 1, 0).reshape(rows, cols);
temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4);
out[5] = temp10.javaTensorAlongDimension(2, 1, 0).reshape(rows, cols);
out[5] = temp10.tensorAlongDimension(2, 1, 0).reshape(rows, cols);
INDArray temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows);
out[6] = temp12.javaTensorAlongDimension(0, 1, 2).reshape(rows, cols);
out[6] = temp12.tensorAlongDimension(0, 1, 2).reshape(rows, cols);
temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows);
out[7] = temp12.javaTensorAlongDimension(2, 1, 2).reshape(rows, cols);
out[7] = temp12.tensorAlongDimension(2, 1, 2).reshape(rows, cols);
INDArray temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols);
out[8] = temp20.javaTensorAlongDimension(0, 2, 0).reshape(rows, cols);
out[8] = temp20.tensorAlongDimension(0, 2, 0).reshape(rows, cols);
temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols);
out[9] = temp20.javaTensorAlongDimension(2, 2, 0).reshape(rows, cols);
out[9] = temp20.tensorAlongDimension(2, 2, 0).reshape(rows, cols);
INDArray temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols);
out[10] = temp21.javaTensorAlongDimension(0, 2, 1).reshape(rows, cols);
out[10] = temp21.tensorAlongDimension(0, 2, 1).reshape(rows, cols);
temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols);
out[11] = temp21.javaTensorAlongDimension(2, 2, 1).reshape(rows, cols);
out[11] = temp21.tensorAlongDimension(2, 2, 1).reshape(rows, cols);
String baseMsg = "getTensorAlongDimensionMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
List<Pair<INDArray, String>> list = new ArrayList<>(12);
@ -361,9 +361,9 @@ public class NDArrayCreationUtil {
val shape4d1 = new long[]{shape[0], shape[1], shape[2], 3};
int lenshape4d1 = ArrayUtil.prod(shape4d1);
INDArray orig1a = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1);
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 0, 1, 2);
INDArray tad1a = orig1a.tensorAlongDimension(0, 0, 1, 2);
INDArray orig1b = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1);
INDArray tad1b = orig1b.javaTensorAlongDimension(1, 0, 1, 2);
INDArray tad1b = orig1b.tensorAlongDimension(1, 0, 1, 2);
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
@ -371,19 +371,19 @@ public class NDArrayCreationUtil {
long[] shape4d2 = {3, shape[0], shape[1], shape[2]};
int lenshape4d2 = ArrayUtil.prod(shape4d2);
INDArray orig2 = Nd4j.linspace(1, lenshape4d2, lenshape4d2, dataType).reshape(shape4d2);
INDArray tad2 = orig2.javaTensorAlongDimension(1, 1, 2, 3);
INDArray tad2 = orig2.tensorAlongDimension(1, 1, 2, 3);
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
long[] shape4d3 = {shape[0], shape[1], 3, shape[2]};
int lenshape4d3 = ArrayUtil.prod(shape4d3);
INDArray orig3 = Nd4j.linspace(1, lenshape4d3, lenshape4d3, dataType).reshape(shape4d3);
INDArray tad3 = orig3.javaTensorAlongDimension(1, 1, 3, 0);
INDArray tad3 = orig3.tensorAlongDimension(1, 1, 3, 0);
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
long[] shape4d4 = {shape[0], 3, shape[1], shape[2]};
int lenshape4d4 = ArrayUtil.prod(shape4d4);
INDArray orig4 = Nd4j.linspace(1, lenshape4d4, lenshape4d4, dataType).reshape(shape4d4);
INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3);
INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3);
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
return list;
@ -513,9 +513,9 @@ public class NDArrayCreationUtil {
int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3]};
int len = ArrayUtil.prod(shape4d1);
INDArray orig1a = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1));
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4);
INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4);
INDArray orig1b = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1));
INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4);
INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4);
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
@ -523,19 +523,19 @@ public class NDArrayCreationUtil {
int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3]};
int len2 = ArrayUtil.prod(shape4d2);
INDArray orig2 = Nd4j.linspace(1, len2, len2, dataType).reshape(ArrayUtil.toLongArray(shape4d2));
INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 4, 2, 1);
INDArray tad2 = orig2.tensorAlongDimension(1, 3, 4, 2, 1);
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
int[] shape4d3 = {shape[0], shape[1], 3, shape[2], shape[3]};
int len3 = ArrayUtil.prod(shape4d3);
INDArray orig3 = Nd4j.linspace(1, len3, len3, dataType).reshape(ArrayUtil.toLongArray(shape4d3));
INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 0);
INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 0);
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3};
int len4 = ArrayUtil.prod(shape4d4);
INDArray orig4 = Nd4j.linspace(1, len4, len4, dataType).reshape(ArrayUtil.toLongArray(shape4d4));
INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3, 1);
INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3, 1);
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
return list;
@ -655,26 +655,26 @@ public class NDArrayCreationUtil {
Nd4j.getRandom().setSeed(seed);
int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]};
INDArray orig1a = Nd4j.rand(dataType, shape4d1);
INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4, 5);
INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4, 5);
INDArray orig1b = Nd4j.rand(dataType, shape4d1);
INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4, 5);
INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4, 5);
list.add(new Pair<>(tad1a, baseMsg + ".get(0)"));
list.add(new Pair<>(tad1b, baseMsg + ".get(1)"));
int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]};
INDArray orig2 = Nd4j.rand(dataType, shape4d2);
INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 5, 4, 2, 1);
INDArray tad2 = orig2.tensorAlongDimension(1, 3, 5, 4, 2, 1);
list.add(new Pair<>(tad2, baseMsg + ".get(2)"));
int[] shape4d3 = {shape[0], shape[1], shape[2], shape[3], shape[4], 2};
INDArray orig3 = Nd4j.rand(dataType, shape4d3);
INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 2, 0);
INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 2, 0);
list.add(new Pair<>(tad3, baseMsg + ".get(3)"));
int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3, shape[4]};
INDArray orig4 = Nd4j.rand(dataType, shape4d4);
INDArray tad4 = orig4.javaTensorAlongDimension(1, 5, 2, 0, 3, 1);
INDArray tad4 = orig4.tensorAlongDimension(1, 5, 2, 0, 3, 1);
list.add(new Pair<>(tad4, baseMsg + ".get(4)"));
return list;
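
The pattern NDArrayCreationUtil relies on throughout is: build a larger array, pull one tensor along a set of dimensions, then reshape it to the target test shape. A stripped-down sketch of that pattern (dimensions chosen arbitrarily; the class name below is not from the source):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TadViewSketch {
    public static void main(String[] args) {
        int rows = 3, cols = 5;
        int len = cols * rows * 4;

        // cols x rows x 4 source array holding 1..len
        INDArray source = Nd4j.linspace(1, len, len, DataType.DOUBLE).reshape(cols, rows, 4);

        // tensor 0 along dimensions (0, 1) covers a cols*rows block;
        // reshaping it produces the rows x cols test matrix
        INDArray matrix = source.tensorAlongDimension(0, 0, 1).reshape(rows, cols);

        System.out.println(java.util.Arrays.toString(matrix.shape()));   // [3, 5]
    }
}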


@ -374,7 +374,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
long nTensors = labels.tensorsAlongDimension(1);
for (int i = 0; i < nTensors; i++) {
INDArray row = labels.tensorAlongDimension(i, 1);
INDArray javaRow = labels.javaTensorAlongDimension(i, 1);
INDArray javaRow = labels.tensorAlongDimension(i, 1);
int maxIdx = Nd4j.getBlasWrapper().iamax(row);
int maxIdxJava = Nd4j.getBlasWrapper().iamax(javaRow);
if (maxIdx < 0)
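
The loop above finds the arg-max of each label row via iamax; since javaTensorAlongDimension and tensorAlongDimension now resolve to the same implementation, maxIdx and maxIdxJava are necessarily equal. A compact sketch of the same pattern on a toy label matrix (values are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LabelArgMaxSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.create(new double[][] {{0, 1, 0}, {0, 0, 1}});

        for (int i = 0; i < labels.tensorsAlongDimension(1); i++) {
            INDArray row = labels.tensorAlongDimension(i, 1);
            // index of the element with the largest absolute value
            System.out.println(Nd4j.getBlasWrapper().iamax(row));   // 1, then 2
        }
    }
}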


@ -1334,7 +1334,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray zC = Nd4j.create(shape, 'c');
zC.setData(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data());
for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) {
INDArray javaTad = zC.javaTensorAlongDimension(tad, dim);
INDArray javaTad = zC.tensorAlongDimension(tad, dim);
System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim));
}


@ -57,7 +57,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < n; i++) {
StopWatch javaTiming = new StopWatch();
javaTiming.start();
row.javaTensorAlongDimension(0, 0);
row.tensorAlongDimension(0, 0);
javaTiming.stop();
StopWatch cTiming = new StopWatch();
cTiming.start();
@ -98,7 +98,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
assertEquals(cols, arr.tensorsAlongDimension(0));
for (int i = 0; i < cols; i++) {
INDArray tad = arr.tensorAlongDimension(i, 0);
INDArray javaTad = arr.javaTensorAlongDimension(i, 0);
INDArray javaTad = arr.tensorAlongDimension(i, 0);
assertEquals(javaTad, tad);
assertArrayEquals(new int[] {rows}, tad.shape());
//assertEquals(testValues.javaTensorAlongDimension(i, 0), tad);
@ -120,7 +120,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
list = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, new int[]{rows, cols, dim2}, DataType.DOUBLE);
for (Pair<INDArray, String> p : list) {
INDArray arr = p.getFirst().assign(testValues);
INDArray javaTad = arr.javaTensorAlongDimension(0, 0);
INDArray javaTad = arr.tensorAlongDimension(0, 0);
INDArray tadTest = arr.tensorAlongDimension(0, 0);
assertEquals(javaTad, tadTest);
//Along dimension 0: expect row vector with length 'rows'
@ -165,7 +165,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
//Along dimension 0,1: expect matrix with shape [rows,cols]
assertEquals(dim2, arr.tensorsAlongDimension(0, 1));
for (int i = 0; i < dim2; i++) {
INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 1);
INDArray javaTad = arr.tensorAlongDimension(i, 0, 1);
INDArray tad = arr.tensorAlongDimension(i, 0, 1);
int javaEleStride = javaTad.elementWiseStride();
int testTad = tad.elementWiseStride();
@ -178,11 +178,11 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
//Along dimension 0,2: expect matrix with shape [rows,dim2]
assertEquals(cols, arr.tensorsAlongDimension(0, 2));
for (int i = 0; i < cols; i++) {
INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 2);
INDArray javaTad = arr.tensorAlongDimension(i, 0, 2);
INDArray tad = arr.tensorAlongDimension(i, 0, 2);
assertEquals(javaTad, tad);
assertArrayEquals(new long[] {rows, dim2}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad);
assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad);
}
//Along dimension 1,2: expect matrix with shape [cols,dim2]
@ -190,7 +190,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < rows; i++) {
INDArray tad = arr.tensorAlongDimension(i, 1, 2);
assertArrayEquals(new long[] {cols, dim2}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad);
assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad);
}
}
@ -207,7 +207,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < dim2 * dim3; i++) {
INDArray tad = arr.tensorAlongDimension(i, 0, 1);
assertArrayEquals(new long[] {rows, cols}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 0, 1), tad);
assertEquals(testValues.tensorAlongDimension(i, 0, 1), tad);
}
//Along dimension 0,2: expect matrix with shape [rows,dim2]
@ -215,7 +215,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < cols * dim3; i++) {
INDArray tad = arr.tensorAlongDimension(i, 0, 2);
assertArrayEquals(new long[] {rows, dim2}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad);
assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad);
}
//Along dimension 0,3: expect matrix with shape [rows,dim3]
@ -223,7 +223,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < cols * dim2; i++) {
INDArray tad = arr.tensorAlongDimension(i, 0, 3);
assertArrayEquals(new long[] {rows, dim3}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 0, 3), tad);
assertEquals(testValues.tensorAlongDimension(i, 0, 3), tad);
}
@ -232,7 +232,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < rows * dim3; i++) {
INDArray tad = arr.tensorAlongDimension(i, 1, 2);
assertArrayEquals(new long[] {cols, dim2}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad);
assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad);
}
//Along dimension 1,3: expect matrix with shape [cols,dim3]
@ -240,7 +240,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < rows * dim2; i++) {
INDArray tad = arr.tensorAlongDimension(i, 1, 3);
assertArrayEquals(new long[] {cols, dim3}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 1, 3), tad);
assertEquals(testValues.tensorAlongDimension(i, 1, 3), tad);
}
//Along dimension 2,3: expect matrix with shape [dim2,dim3]
@ -248,7 +248,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest {
for (int i = 0; i < rows * cols; i++) {
INDArray tad = arr.tensorAlongDimension(i, 2, 3);
assertArrayEquals(new long[] {dim2, dim3}, tad.shape());
assertEquals(testValues.javaTensorAlongDimension(i, 2, 3), tad);
assertEquals(testValues.tensorAlongDimension(i, 2, 3), tad);
}
}
}
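
The expectations asserted above follow one rule: a tensor along dimensions (d1, d2) has shape [size(d1), size(d2)], and there are length / (size(d1) * size(d2)) such tensors. A sketch of that rule for the 3d case (shape values arbitrary; Nd4j.rand is called as in the test above):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TadShapeRuleSketch {
    public static void main(String[] args) {
        int rows = 7, cols = 8, dim2 = 9;
        INDArray arr = Nd4j.rand(DataType.DOUBLE, new int[] {rows, cols, dim2});

        // one tensor per index of the remaining dimension (dimension 0 here)
        System.out.println(arr.tensorsAlongDimension(1, 2));                  // 7
        // each tensor is a cols x dim2 matrix
        System.out.println(java.util.Arrays.toString(
                arr.tensorAlongDimension(0, 1, 2).shape()));                  // [8, 9]
    }
}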


@ -614,7 +614,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
assertEquals(arr5s.getDouble(i), 16, 1e-1);
System.out.println("6d");
INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4);
INDArray arr6Tad = arr6.javaTensorAlongDimension(0, 2, 3);
INDArray arr6Tad = arr6.tensorAlongDimension(0, 2, 3);
INDArray arr6s = arr6.sum(2, 3);
for (int i = 0; i < arr6s.length(); i++)
assertEquals(arr6s.getDouble(i), 16, 1e-1);
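
The assertion above checks that summing the 4 x 4 block of ones along dimensions 2 and 3 yields 16 in every cell. The same reduction on a smaller array, as a sanity sketch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SumAlongDimsSketch {
    public static void main(String[] args) {
        INDArray ones = Nd4j.ones(2, 3, 4);

        // summing over dimensions 1 and 2 collapses 3 * 4 = 12 ones into each cell
        INDArray summed = ones.sum(1, 2);
        System.out.println(summed);   // two elements, both 12.0
    }
}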


@ -79,7 +79,7 @@ public class TADTests extends BaseNd4jTest {
int[] shape = new int[] {e, x};
Arrays.sort(shape);
INDArray assertion = array.javaTensorAlongDimension(0, shape);
INDArray assertion = array.tensorAlongDimension(0, shape);
INDArray test = array.tensorAlongDimension(0, shape);
assertEquals(assertion, test);
@ -101,7 +101,7 @@ public class TADTests extends BaseNd4jTest {
Arrays.sort(shape);
log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape "
+ array.shapeInfoToString());
INDArray assertion = array.javaTensorAlongDimension(0, shape);
INDArray assertion = array.tensorAlongDimension(0, shape);
INDArray test = array.tensorAlongDimension(0, shape);
assertEquals(assertion, test);
//assertEquals(assertion.shapeInfoDataBuffer(), test.shapeInfoDataBuffer());
@ -121,8 +121,8 @@ public class TADTests extends BaseNd4jTest {
public void testMysteriousCrash() {
INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f');
INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c');
INDArray javaCTad = arrayC.javaTensorAlongDimension(0, 2, 3);
INDArray javaFTad = arrayF.javaTensorAlongDimension(0, 2, 3);
INDArray javaCTad = arrayC.tensorAlongDimension(0, 2, 3);
INDArray javaFTad = arrayF.tensorAlongDimension(0, 2, 3);
Pair<DataBuffer, DataBuffer> tadBuffersF =
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3);
Pair<DataBuffer, DataBuffer> tadBuffersC =
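
testMysteriousCrash compares TADs of identically shaped 'c' and 'f' ordered arrays. A minimal sketch of that comparison, leaving out the TAD-manager shape-info buffers (what the full test asserts is not visible in this hunk):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TadOrderingSketch {
    public static void main(String[] args) {
        INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c');
        INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f');

        INDArray cTad = arrayC.tensorAlongDimension(0, 2, 3);
        INDArray fTad = arrayF.tensorAlongDimension(0, 2, 3);

        // element-wise equality holds regardless of the underlying ordering
        System.out.println(cTad.equals(fTad));   // true
    }
}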


@ -185,7 +185,7 @@ public class ConcatTestsC extends BaseNd4jTest {
//ConcatV2, dim 1
second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4);
for (int i = 0; i < second.tensorsAlongDimension(1); i++) {
INDArray secondTad = second.javaTensorAlongDimension(i, 1);
INDArray secondTad = second.tensorAlongDimension(i, 1);
System.out.println(second.tensorAlongDimension(i, 1));
}
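
The hunk above only prints the dim-1 tensors of the second concat operand; the first operand is not visible, so its shape in the sketch below is an assumption chosen to be concat-compatible along dimension 1:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ConcatDim1Sketch {
    public static void main(String[] args) {
        // assumed first operand: shape [2, 3, 4]
        INDArray first  = Nd4j.linspace(0, 23, 24, Nd4j.dataType()).reshape('c', 2, 3, 4);
        // second operand as in the test: shape [2, 1, 4]
        INDArray second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4);

        INDArray concat = Nd4j.concat(1, first, second);   // shape [2, 4, 4]
        System.out.println(java.util.Arrays.toString(concat.shape()));

        // iterate the dim-1 tensors of the second operand, as the test does
        for (int i = 0; i < second.tensorsAlongDimension(1); i++) {
            System.out.println(second.tensorAlongDimension(i, 1));
        }
    }
}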