diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 431f24496..7c5c960e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -343,9 +343,8 @@ public class Nd4j { /** - * Append the given - * array with the specified value size - * along a particular axis + * Append the given array with the specified value size along a particular axis. + * The prepend method has the same signature and prepends the given array. * @param arr the array to append to * @param padAmount the pad amount of the array to be returned * @param val the value to append @@ -353,6 +352,22 @@ public class Nd4j { * @return the newly created array */ public static INDArray append(INDArray arr, int padAmount, double val, int axis) { + return appendImpl(arr, padAmount, val, axis, true); + } + + /** + * @see #append(INDArray, int, double, int) + */ + public static INDArray prepend(INDArray arr, int padAmount, double val, int axis) { + return appendImpl(arr, padAmount, val, axis, false); + } + + /** + * Append / Prepend shared implementation. + * @param appendFlag flag to determine Append / Prepend. + * @see #append(INDArray, int, double, int) + */ + private static INDArray appendImpl(INDArray arr, int padAmount, double val, int axis, boolean appendFlag){ if (padAmount == 0) return arr; long[] paShape = ArrayUtil.copy(arr.shape()); @@ -360,31 +375,9 @@ public class Nd4j { axis = axis + arr.shape().length; paShape[axis] = padAmount; INDArray concatArray = Nd4j.valueArrayOf(paShape, val, arr.dataType()); - return Nd4j.concat(axis, arr, concatArray); + return appendFlag ? Nd4j.concat(axis, arr, concatArray) : Nd4j.concat(axis, concatArray, arr); } - - /** - * Append the given - * array with the specified value size - * along a particular axis - * @param arr the array to append to - * @param padAmount the pad amount of the array to be returned - * @param val the value to append - * @param axis the axis to append to - * @return the newly created array - */ - public static INDArray prepend(INDArray arr, int padAmount, double val, int axis) { - if (padAmount == 0) - return arr; - - long[] paShape = ArrayUtil.copy(arr.shape()); - if (axis < 0) - axis = axis + arr.shape().length; - paShape[axis] = padAmount; - INDArray concatArr = Nd4j.valueArrayOf(paShape, val, arr.dataType()); - return Nd4j.concat(axis, concatArr, arr); - } - + /** * Expand the array dimensions. * This is equivalent to