INDArray refactoring (#170)
* javadoc Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove javaTensorAlongDimension Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * javadoc
This commit is contained in:
		
							parent
							
								
									d31197db5f
								
							
						
					
					
						commit
						59a6e4e3ae
					
				| @ -1007,19 +1007,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return toTad; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the vector along a particular dimension | ||||
|      * | ||||
|      * @param index     the index of the vector to getScalar | ||||
|      * @param dimension the dimension to getScalar the vector from | ||||
|      * @return the vector along a particular dimension | ||||
|      */ | ||||
|     @Override | ||||
|     @Deprecated | ||||
|     public INDArray javaTensorAlongDimension(int index, int... dimension) { | ||||
|         return doTad(index, dimension); | ||||
|     } | ||||
| 
 | ||||
|     private void setShapeInformation(Pair<DataBuffer, long[]> shapeInfo) { | ||||
|         this.shapeInformation = shapeInfo.getFirst(); | ||||
|         this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); | ||||
| @ -1110,14 +1097,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return ret2.permutei(finalPermuteDims); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the number of possible vectors for a given dimension | ||||
|      * | ||||
|      * @param dimension the dimension to calculate the number of vectors for | ||||
|      * @return the number of possible vectors along a dimension | ||||
|      */ | ||||
|     @Override | ||||
|     public long vectorsAlongDimension(int dimension) { | ||||
|         if (dimension == 0 && isVector() || isRowVectorOrScalar()) | ||||
| @ -1150,17 +1129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return length / size(dimension); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the vector along a particular dimension | ||||
|      * | ||||
|      * @param index     the index of the vector to get | ||||
|      * @param dimension the dimension to get the vector from | ||||
|      * @return the vector along a particular dimension | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray vectorAlongDimension(int index, int dimension) { | ||||
|         if (dimension < 0) | ||||
|         if (dimension < 0) { | ||||
|             dimension = jvmShapeInfo.getRank() + dimension; | ||||
|         } | ||||
| 
 | ||||
|         //return the whole thing | ||||
|         if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 | ||||
| @ -1168,12 +1141,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         INDArray ret = tensorAlongDimension(index, dimension); | ||||
|         //if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector()) | ||||
|         //    return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); | ||||
|         //else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector()) | ||||
|         //    return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); | ||||
|         return ret; | ||||
|         return tensorAlongDimension(index, dimension); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| @ -1196,13 +1164,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride),  0, ordering(), this.dataType(), false)); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Cumulative sum along a dimension | ||||
|      * | ||||
|      * @param dimension the dimension to perform cumulative sum along | ||||
|      * @return the cumulative sum along the specified dimension | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray cumsumi(int dimension) { | ||||
|         validateNumericalArray("cumsumi", true); | ||||
| @ -1351,25 +1312,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return logEntropy(Integer.MAX_VALUE).getDouble(0); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Cumulative sum along a dimension (in place) | ||||
|      * | ||||
|      * @param dimension the dimension to perform cumulative sum along | ||||
|      * @return the cumulative sum along the specified dimension | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray cumsum(int dimension) { | ||||
|         validateNumericalArray("cumsum", true); | ||||
|         return dup().cumsumi(dimension); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Assign all of the elements in the given | ||||
|      * ndarray to this ndarray | ||||
|      * | ||||
|      * @param arr the elements to assign | ||||
|      * @return this | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray assign(final INDArray arr) { | ||||
|         Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), | ||||
| @ -1378,7 +1326,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
| 
 | ||||
|         Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal"); | ||||
| 
 | ||||
|         //Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set(this, arr, this, length())); | ||||
|         Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); | ||||
|         return this; | ||||
|     } | ||||
| @ -1413,7 +1360,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|     @Override | ||||
|     public INDArray putScalar(long i, float value) { | ||||
|         return putScalar(i, (double) value); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| @ -1540,7 +1486,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray putScalar(int[] indexes, float value) { | ||||
|         return putScalar(indexes, (double) value); | ||||
| @ -1556,27 +1501,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return putScalar(indexes, (double) value); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Returns an ndarray with 1 if the element is epsilon equals | ||||
|      * | ||||
|      * @param other the number to compare | ||||
|      * @return a copied ndarray with the given | ||||
|      * binary conditions | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray eps(Number other) { | ||||
|         validateNumericalArray("eps", true); | ||||
|         return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * epsilon equals than comparison: | ||||
|      * If the given number is less than the | ||||
|      * comparison number the item is 0 otherwise 1 | ||||
|      * | ||||
|      * @param other the number to compare | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray eps(INDArray other) { | ||||
|         validateNumericalArray("eps", true); | ||||
| @ -1613,7 +1543,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray lt(INDArray other) { | ||||
|         validateNumericalArray("less than (lt)", false); | ||||
| @ -1675,9 +1604,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan())); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Negate each element. | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray neg() { | ||||
|         validateNumericalArray("negative (neg)", true); | ||||
| @ -1686,9 +1612,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Negate each element (in-place). | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray negi() { | ||||
|         validateNumericalArray("negative (negi)", true); | ||||
| @ -3909,28 +3832,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { | ||||
|         return assign(value ? 1 : 0); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Assign all elements from given ndarray that are matching given condition, | ||||
|      * ndarray to this ndarray | ||||
|      * | ||||
|      * @param arr       the elements to assign | ||||
|      * @param condition | ||||
|      * @return this | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray assignIf(INDArray arr, Condition condition) { | ||||
|         BooleanIndexing.assignIf(this, arr, condition); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Replaces all elements in this ndarray that are matching give condition, with corresponding elements from given array | ||||
|      * | ||||
|      * @param arr | ||||
|      * @param condition | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public INDArray replaceWhere(INDArray arr, Condition condition) { | ||||
|         Nd4j.getCompressor().autoDecompress(this); | ||||
|  | ||||
| @ -411,11 +411,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray javaTensorAlongDimension(int index, int... dimension) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray cumsumi(int dimension) { | ||||
|         return null; | ||||
| @ -476,7 +471,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray isInfinite() { | ||||
|         throw new UnsupportedOperationException(); | ||||
| @ -551,6 +545,7 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { | ||||
|     public INDArray lte(Number other) { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray lt(INDArray other) { | ||||
|         return null; | ||||
|  | ||||
| @ -106,27 +106,31 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
| 
 | ||||
|     /** | ||||
|      * Element wise stride | ||||
|      * @return the element wise stride | ||||
|      */ | ||||
|     int elementWiseStride(); | ||||
| 
 | ||||
|     /** | ||||
|      * Get a scalar | ||||
|      * at the given linear offset | ||||
|      * Get a double at the given linear offset unsafe, without checks. | ||||
|      * @param offset the offset to get at | ||||
|      * @return this | ||||
|      * @return double value at offset | ||||
|      */ | ||||
|     double getDoubleUnsafe(long offset); | ||||
|     double getDoubleUnsafe(long offset); //TODO: consider deleting. | ||||
| 
 | ||||
|     /** | ||||
|      * Get string value at given index. | ||||
|      * @param index index to retreive | ||||
|      * @return string value at index. | ||||
|      */ | ||||
|     String getString(long index); | ||||
| 
 | ||||
|     /** | ||||
|      * Insert a scalar | ||||
|      * at the given linear offset | ||||
|      * Insert a scalar at the given linear offset | ||||
|      * @param offset the offset to insert at | ||||
|      * @param value the value to insert | ||||
|      * @return this | ||||
|      */ | ||||
|     INDArray putScalarUnsafe(long offset, double value); | ||||
|     INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting. | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the number of possible vectors for a given dimension | ||||
| @ -162,17 +166,6 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray tensorAlongDimension(long index, int... dimension); | ||||
| 
 | ||||
|     /** | ||||
|      * Get the vector along a particular dimension | ||||
|      * | ||||
|      * @param index     the index of the vector to getScalar | ||||
|      * @param dimension the dimension to getScalar the vector from | ||||
|      * @return the vector along a particular dimension | ||||
|      */ | ||||
|     @Deprecated | ||||
|     INDArray javaTensorAlongDimension(int index, int... dimension); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the cumulative sum along a dimension. In-place method. | ||||
|      * | ||||
| @ -190,8 +183,7 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|     INDArray cumsum(int dimension); | ||||
| 
 | ||||
|     /** | ||||
|      * Assign all of the elements in the given | ||||
|      * ndarray to this ndarray | ||||
|      * Assign all of the elements in the given ndarray to this ndarray | ||||
|      * | ||||
|      * @param arr the elements to assign | ||||
|      * @return this | ||||
| @ -254,9 +246,19 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray putScalar(int[] i, double value); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * See {@link #putScalar(int[], double)} | ||||
|      */ | ||||
|     INDArray putScalar(long[] i, double value); | ||||
| 
 | ||||
|     /** | ||||
|      * See {@link #putScalar(int[], double)} | ||||
|      */ | ||||
|     INDArray putScalar(long[] i, float value); | ||||
| 
 | ||||
|     /** | ||||
|      * See {@link #putScalar(int[], double)} | ||||
|      */ | ||||
|     INDArray putScalar(long[] i, int value); | ||||
| 
 | ||||
|     /** | ||||
| @ -300,7 +302,6 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray lt(Number other); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Put the specified float value at the specified indices in this array | ||||
|      * | ||||
| @ -327,8 +328,6 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray eps(Number other); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the binary ndarray for "Equals" comparison. | ||||
|      * | ||||
| @ -367,10 +366,8 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      * @param other the ndarray to compare. | ||||
|      * @return the binary ndarray for "Less" comparison. | ||||
|      */ | ||||
| 
 | ||||
|     INDArray lt(INDArray other); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the binary ndarray for "Epsilon equals" comparison. | ||||
|      * | ||||
| @ -403,7 +400,6 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray eq(INDArray other); | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the binary ndarray for "Greater Than" comparison. | ||||
|      * | ||||
| @ -424,8 +420,6 @@ public interface INDArray extends Serializable, AutoCloseable { | ||||
|      */ | ||||
|     INDArray isNaN(); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the ndarray negative (cloned) | ||||
|      * | ||||
|  | ||||
| @ -181,37 +181,37 @@ public class NDArrayCreationUtil { | ||||
|         INDArray[] out = new INDArray[12]; | ||||
| 
 | ||||
|         INDArray temp01 = Nd4j.linspace(1, cols * rows * 4, cols * rows * 4, dataType).reshape(cols, rows, 4); | ||||
|         out[0] = temp01.javaTensorAlongDimension(0, 0, 1).reshape(rows, cols); | ||||
|         out[0] = temp01.tensorAlongDimension(0, 0, 1).reshape(rows, cols); | ||||
|         long[] temp01Shape = new long[] {cols, rows, 4}; | ||||
|         int len = ArrayUtil.prod(temp01Shape); | ||||
|         temp01 = Nd4j.linspace(1, len, len, dataType).reshape(temp01Shape); | ||||
|         out[1] = temp01.javaTensorAlongDimension(2, 0, 1).reshape(rows, cols); | ||||
|         out[1] = temp01.tensorAlongDimension(2, 0, 1).reshape(rows, cols); | ||||
| 
 | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(new long[] {cols, 4, rows}); | ||||
|         out[2] = temp02.javaTensorAlongDimension(0, 0, 2).reshape(rows, cols); | ||||
|         out[2] = temp02.tensorAlongDimension(0, 0, 2).reshape(rows, cols); | ||||
|         temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows); | ||||
|         out[3] = temp02.javaTensorAlongDimension(2, 0, 2).reshape(rows, cols); | ||||
|         out[3] = temp02.tensorAlongDimension(2, 0, 2).reshape(rows, cols); | ||||
| 
 | ||||
|         INDArray temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4); | ||||
|         out[4] = temp10.javaTensorAlongDimension(0, 1, 0).reshape(rows, cols); | ||||
|         out[4] = temp10.tensorAlongDimension(0, 1, 0).reshape(rows, cols); | ||||
|         temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4); | ||||
|         out[5] = temp10.javaTensorAlongDimension(2, 1, 0).reshape(rows, cols); | ||||
|         out[5] = temp10.tensorAlongDimension(2, 1, 0).reshape(rows, cols); | ||||
| 
 | ||||
|         INDArray temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows); | ||||
|         out[6] = temp12.javaTensorAlongDimension(0, 1, 2).reshape(rows, cols); | ||||
|         out[6] = temp12.tensorAlongDimension(0, 1, 2).reshape(rows, cols); | ||||
|         temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows); | ||||
|         out[7] = temp12.javaTensorAlongDimension(2, 1, 2).reshape(rows, cols); | ||||
|         out[7] = temp12.tensorAlongDimension(2, 1, 2).reshape(rows, cols); | ||||
| 
 | ||||
|         INDArray temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols); | ||||
|         out[8] = temp20.javaTensorAlongDimension(0, 2, 0).reshape(rows, cols); | ||||
|         out[8] = temp20.tensorAlongDimension(0, 2, 0).reshape(rows, cols); | ||||
|         temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols); | ||||
|         out[9] = temp20.javaTensorAlongDimension(2, 2, 0).reshape(rows, cols); | ||||
|         out[9] = temp20.tensorAlongDimension(2, 2, 0).reshape(rows, cols); | ||||
| 
 | ||||
|         INDArray temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols); | ||||
|         out[10] = temp21.javaTensorAlongDimension(0, 2, 1).reshape(rows, cols); | ||||
|         out[10] = temp21.tensorAlongDimension(0, 2, 1).reshape(rows, cols); | ||||
|         temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols); | ||||
|         out[11] = temp21.javaTensorAlongDimension(2, 2, 1).reshape(rows, cols); | ||||
|         out[11] = temp21.tensorAlongDimension(2, 2, 1).reshape(rows, cols); | ||||
| 
 | ||||
|         String baseMsg = "getTensorAlongDimensionMatricesWithShape(" + rows + "," + cols + "," + seed + ")"; | ||||
|         List<Pair<INDArray, String>> list = new ArrayList<>(12); | ||||
| @ -361,9 +361,9 @@ public class NDArrayCreationUtil { | ||||
|         val shape4d1 = new long[]{shape[0], shape[1], shape[2], 3}; | ||||
|         int lenshape4d1 = ArrayUtil.prod(shape4d1); | ||||
|         INDArray orig1a = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1); | ||||
|         INDArray tad1a = orig1a.javaTensorAlongDimension(0, 0, 1, 2); | ||||
|         INDArray tad1a = orig1a.tensorAlongDimension(0, 0, 1, 2); | ||||
|         INDArray orig1b = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1); | ||||
|         INDArray tad1b = orig1b.javaTensorAlongDimension(1, 0, 1, 2); | ||||
|         INDArray tad1b = orig1b.tensorAlongDimension(1, 0, 1, 2); | ||||
| 
 | ||||
|         list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); | ||||
|         list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); | ||||
| @ -371,19 +371,19 @@ public class NDArrayCreationUtil { | ||||
|         long[] shape4d2 = {3, shape[0], shape[1], shape[2]}; | ||||
|         int lenshape4d2 = ArrayUtil.prod(shape4d2); | ||||
|         INDArray orig2 = Nd4j.linspace(1, lenshape4d2, lenshape4d2, dataType).reshape(shape4d2); | ||||
|         INDArray tad2 = orig2.javaTensorAlongDimension(1, 1, 2, 3); | ||||
|         INDArray tad2 = orig2.tensorAlongDimension(1, 1, 2, 3); | ||||
|         list.add(new Pair<>(tad2, baseMsg + ".get(2)")); | ||||
| 
 | ||||
|         long[] shape4d3 = {shape[0], shape[1], 3, shape[2]}; | ||||
|         int lenshape4d3 = ArrayUtil.prod(shape4d3); | ||||
|         INDArray orig3 = Nd4j.linspace(1, lenshape4d3, lenshape4d3, dataType).reshape(shape4d3); | ||||
|         INDArray tad3 = orig3.javaTensorAlongDimension(1, 1, 3, 0); | ||||
|         INDArray tad3 = orig3.tensorAlongDimension(1, 1, 3, 0); | ||||
|         list.add(new Pair<>(tad3, baseMsg + ".get(3)")); | ||||
| 
 | ||||
|         long[] shape4d4 = {shape[0], 3, shape[1], shape[2]}; | ||||
|         int lenshape4d4 = ArrayUtil.prod(shape4d4); | ||||
|         INDArray orig4 = Nd4j.linspace(1, lenshape4d4, lenshape4d4, dataType).reshape(shape4d4); | ||||
|         INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3); | ||||
|         INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3); | ||||
|         list.add(new Pair<>(tad4, baseMsg + ".get(4)")); | ||||
| 
 | ||||
|         return list; | ||||
| @ -513,9 +513,9 @@ public class NDArrayCreationUtil { | ||||
|         int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3]}; | ||||
|         int len = ArrayUtil.prod(shape4d1); | ||||
|         INDArray orig1a = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1)); | ||||
|         INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4); | ||||
|         INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4); | ||||
|         INDArray orig1b = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1)); | ||||
|         INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4); | ||||
|         INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4); | ||||
| 
 | ||||
|         list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); | ||||
|         list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); | ||||
| @ -523,19 +523,19 @@ public class NDArrayCreationUtil { | ||||
|         int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3]}; | ||||
|         int len2 = ArrayUtil.prod(shape4d2); | ||||
|         INDArray orig2 = Nd4j.linspace(1, len2, len2, dataType).reshape(ArrayUtil.toLongArray(shape4d2)); | ||||
|         INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 4, 2, 1); | ||||
|         INDArray tad2 = orig2.tensorAlongDimension(1, 3, 4, 2, 1); | ||||
|         list.add(new Pair<>(tad2, baseMsg + ".get(2)")); | ||||
| 
 | ||||
|         int[] shape4d3 = {shape[0], shape[1], 3, shape[2], shape[3]}; | ||||
|         int len3 = ArrayUtil.prod(shape4d3); | ||||
|         INDArray orig3 = Nd4j.linspace(1, len3, len3, dataType).reshape(ArrayUtil.toLongArray(shape4d3)); | ||||
|         INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 0); | ||||
|         INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 0); | ||||
|         list.add(new Pair<>(tad3, baseMsg + ".get(3)")); | ||||
| 
 | ||||
|         int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3}; | ||||
|         int len4 = ArrayUtil.prod(shape4d4); | ||||
|         INDArray orig4 = Nd4j.linspace(1, len4, len4, dataType).reshape(ArrayUtil.toLongArray(shape4d4)); | ||||
|         INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3, 1); | ||||
|         INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3, 1); | ||||
|         list.add(new Pair<>(tad4, baseMsg + ".get(4)")); | ||||
| 
 | ||||
|         return list; | ||||
| @ -655,26 +655,26 @@ public class NDArrayCreationUtil { | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]}; | ||||
|         INDArray orig1a = Nd4j.rand(dataType, shape4d1); | ||||
|         INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4, 5); | ||||
|         INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4, 5); | ||||
|         INDArray orig1b = Nd4j.rand(dataType, shape4d1); | ||||
|         INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4, 5); | ||||
|         INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4, 5); | ||||
| 
 | ||||
|         list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); | ||||
|         list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); | ||||
| 
 | ||||
|         int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]}; | ||||
|         INDArray orig2 = Nd4j.rand(dataType, shape4d2); | ||||
|         INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 5, 4, 2, 1); | ||||
|         INDArray tad2 = orig2.tensorAlongDimension(1, 3, 5, 4, 2, 1); | ||||
|         list.add(new Pair<>(tad2, baseMsg + ".get(2)")); | ||||
| 
 | ||||
|         int[] shape4d3 = {shape[0], shape[1], shape[2], shape[3], shape[4], 2}; | ||||
|         INDArray orig3 = Nd4j.rand(dataType, shape4d3); | ||||
|         INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 2, 0); | ||||
|         INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 2, 0); | ||||
|         list.add(new Pair<>(tad3, baseMsg + ".get(3)")); | ||||
| 
 | ||||
|         int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3, shape[4]}; | ||||
|         INDArray orig4 = Nd4j.rand(dataType, shape4d4); | ||||
|         INDArray tad4 = orig4.javaTensorAlongDimension(1, 5, 2, 0, 3, 1); | ||||
|         INDArray tad4 = orig4.tensorAlongDimension(1, 5, 2, 0, 3, 1); | ||||
|         list.add(new Pair<>(tad4, baseMsg + ".get(4)")); | ||||
| 
 | ||||
|         return list; | ||||
|  | ||||
| @ -374,7 +374,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { | ||||
|         long nTensors = labels.tensorsAlongDimension(1); | ||||
|         for (int i = 0; i < nTensors; i++) { | ||||
|             INDArray row = labels.tensorAlongDimension(i, 1); | ||||
|             INDArray javaRow = labels.javaTensorAlongDimension(i, 1); | ||||
|             INDArray javaRow = labels.tensorAlongDimension(i, 1); | ||||
|             int maxIdx = Nd4j.getBlasWrapper().iamax(row); | ||||
|             int maxIdxJava = Nd4j.getBlasWrapper().iamax(javaRow); | ||||
|             if (maxIdx < 0) | ||||
|  | ||||
| @ -1334,7 +1334,7 @@ public class Nd4jTestsC extends BaseNd4jTest { | ||||
|             INDArray zC = Nd4j.create(shape, 'c'); | ||||
|             zC.setData(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data()); | ||||
|             for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) { | ||||
|                 INDArray javaTad = zC.javaTensorAlongDimension(tad, dim); | ||||
|                 INDArray javaTad = zC.tensorAlongDimension(tad, dim); | ||||
|                 System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim)); | ||||
|             } | ||||
| 
 | ||||
|  | ||||
| @ -57,7 +57,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|         for (int i = 0; i < n; i++) { | ||||
|             StopWatch javaTiming = new StopWatch(); | ||||
|             javaTiming.start(); | ||||
|             row.javaTensorAlongDimension(0, 0); | ||||
|             row.tensorAlongDimension(0, 0); | ||||
|             javaTiming.stop(); | ||||
|             StopWatch cTiming = new StopWatch(); | ||||
|             cTiming.start(); | ||||
| @ -98,7 +98,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             assertEquals(cols, arr.tensorsAlongDimension(0)); | ||||
|             for (int i = 0; i < cols; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0); | ||||
|                 INDArray javaTad = arr.javaTensorAlongDimension(i, 0); | ||||
|                 INDArray javaTad = arr.tensorAlongDimension(i, 0); | ||||
|                 assertEquals(javaTad, tad); | ||||
|                 assertArrayEquals(new int[] {rows}, tad.shape()); | ||||
|                 //assertEquals(testValues.javaTensorAlongDimension(i, 0), tad); | ||||
| @ -120,7 +120,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|         list = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, new int[]{rows, cols, dim2}, DataType.DOUBLE); | ||||
|         for (Pair<INDArray, String> p : list) { | ||||
|             INDArray arr = p.getFirst().assign(testValues); | ||||
|             INDArray javaTad = arr.javaTensorAlongDimension(0, 0); | ||||
|             INDArray javaTad = arr.tensorAlongDimension(0, 0); | ||||
|             INDArray tadTest = arr.tensorAlongDimension(0, 0); | ||||
|             assertEquals(javaTad, tadTest); | ||||
|             //Along dimension 0: expect row vector with length 'rows' | ||||
| @ -165,7 +165,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             //Along dimension 0,1: expect matrix with shape [rows,cols] | ||||
|             assertEquals(dim2, arr.tensorsAlongDimension(0, 1)); | ||||
|             for (int i = 0; i < dim2; i++) { | ||||
|                 INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 1); | ||||
|                 INDArray javaTad = arr.tensorAlongDimension(i, 0, 1); | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0, 1); | ||||
|                 int javaEleStride = javaTad.elementWiseStride(); | ||||
|                 int testTad = tad.elementWiseStride(); | ||||
| @ -178,11 +178,11 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             //Along dimension 0,2: expect matrix with shape [rows,dim2] | ||||
|             assertEquals(cols, arr.tensorsAlongDimension(0, 2)); | ||||
|             for (int i = 0; i < cols; i++) { | ||||
|                 INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 2); | ||||
|                 INDArray javaTad = arr.tensorAlongDimension(i, 0, 2); | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0, 2); | ||||
|                 assertEquals(javaTad, tad); | ||||
|                 assertArrayEquals(new long[] {rows, dim2}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad); | ||||
|             } | ||||
| 
 | ||||
|             //Along dimension 1,2: expect matrix with shape [cols,dim2] | ||||
| @ -190,7 +190,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < rows; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 1, 2); | ||||
|                 assertArrayEquals(new long[] {cols, dim2}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| @ -207,7 +207,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < dim2 * dim3; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0, 1); | ||||
|                 assertArrayEquals(new long[] {rows, cols}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 0, 1), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 0, 1), tad); | ||||
|             } | ||||
| 
 | ||||
|             //Along dimension 0,2: expect matrix with shape [rows,dim2] | ||||
| @ -215,7 +215,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < cols * dim3; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0, 2); | ||||
|                 assertArrayEquals(new long[] {rows, dim2}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad); | ||||
|             } | ||||
| 
 | ||||
|             //Along dimension 0,3: expect matrix with shape [rows,dim3] | ||||
| @ -223,7 +223,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < cols * dim2; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 0, 3); | ||||
|                 assertArrayEquals(new long[] {rows, dim3}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 0, 3), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 0, 3), tad); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
| @ -232,7 +232,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < rows * dim3; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 1, 2); | ||||
|                 assertArrayEquals(new long[] {cols, dim2}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad); | ||||
|             } | ||||
| 
 | ||||
|             //Along dimension 1,3: expect matrix with shape [cols,dim3] | ||||
| @ -240,7 +240,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < rows * dim2; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 1, 3); | ||||
|                 assertArrayEquals(new long[] {cols, dim3}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 1, 3), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 1, 3), tad); | ||||
|             } | ||||
| 
 | ||||
|             //Along dimension 2,3: expect matrix with shape [dim2,dim3] | ||||
| @ -248,7 +248,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { | ||||
|             for (int i = 0; i < rows * cols; i++) { | ||||
|                 INDArray tad = arr.tensorAlongDimension(i, 2, 3); | ||||
|                 assertArrayEquals(new long[] {dim2, dim3}, tad.shape()); | ||||
|                 assertEquals(testValues.javaTensorAlongDimension(i, 2, 3), tad); | ||||
|                 assertEquals(testValues.tensorAlongDimension(i, 2, 3), tad); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @ -614,7 +614,7 @@ public class OpExecutionerTests extends BaseNd4jTest { | ||||
|             assertEquals(arr5s.getDouble(i), 16, 1e-1); | ||||
|         System.out.println("6d"); | ||||
|         INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); | ||||
|         INDArray arr6Tad = arr6.javaTensorAlongDimension(0, 2, 3); | ||||
|         INDArray arr6Tad = arr6.tensorAlongDimension(0, 2, 3); | ||||
|         INDArray arr6s = arr6.sum(2, 3); | ||||
|         for (int i = 0; i < arr6s.length(); i++) | ||||
|             assertEquals(arr6s.getDouble(i), 16, 1e-1); | ||||
|  | ||||
| @ -79,7 +79,7 @@ public class TADTests extends BaseNd4jTest { | ||||
| 
 | ||||
|                     int[] shape = new int[] {e, x}; | ||||
|                     Arrays.sort(shape); | ||||
|                     INDArray assertion = array.javaTensorAlongDimension(0, shape); | ||||
|                     INDArray assertion = array.tensorAlongDimension(0, shape); | ||||
|                     INDArray test = array.tensorAlongDimension(0, shape); | ||||
| 
 | ||||
|                     assertEquals(assertion, test); | ||||
| @ -101,7 +101,7 @@ public class TADTests extends BaseNd4jTest { | ||||
|                 Arrays.sort(shape); | ||||
|                 log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape " | ||||
|                                 + array.shapeInfoToString()); | ||||
|                 INDArray assertion = array.javaTensorAlongDimension(0, shape); | ||||
|                 INDArray assertion = array.tensorAlongDimension(0, shape); | ||||
|                 INDArray test = array.tensorAlongDimension(0, shape); | ||||
|                 assertEquals(assertion, test); | ||||
|                 //assertEquals(assertion.shapeInfoDataBuffer(), test.shapeInfoDataBuffer()); | ||||
| @ -121,8 +121,8 @@ public class TADTests extends BaseNd4jTest { | ||||
|     public void testMysteriousCrash() { | ||||
|         INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f'); | ||||
|         INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c'); | ||||
|         INDArray javaCTad = arrayC.javaTensorAlongDimension(0, 2, 3); | ||||
|         INDArray javaFTad = arrayF.javaTensorAlongDimension(0, 2, 3); | ||||
|         INDArray javaCTad = arrayC.tensorAlongDimension(0, 2, 3); | ||||
|         INDArray javaFTad = arrayF.tensorAlongDimension(0, 2, 3); | ||||
|         Pair<DataBuffer, DataBuffer> tadBuffersF = | ||||
|                         Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3); | ||||
|         Pair<DataBuffer, DataBuffer> tadBuffersC = | ||||
|  | ||||
| @ -185,7 +185,7 @@ public class ConcatTestsC extends BaseNd4jTest { | ||||
|         //ConcatV2, dim 1 | ||||
|         second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4); | ||||
|         for (int i = 0; i < second.tensorsAlongDimension(1); i++) { | ||||
|             INDArray secondTad = second.javaTensorAlongDimension(i, 1); | ||||
|             INDArray secondTad = second.tensorAlongDimension(i, 1); | ||||
|             System.out.println(second.tensorAlongDimension(i, 1)); | ||||
|         } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user