Remove duplicate code for append() and prepend() Fix #7983 (#7987)

* fix #7983

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* small fix.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix javadoc.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-07-09 12:58:50 +09:00 committed by Alex Black
parent 595656d01e
commit 0c48e55f91
1 changed files with 20 additions and 27 deletions

View File

@ -343,9 +343,8 @@ public class Nd4j {
/**
* Append the given
* array with the specified value size
* along a particular axis
* Append the given array with the specified value size along a particular axis.
* The prepend method has the same signature and prepends the given array.
* @param arr the array to append to
* @param padAmount the pad amount of the array to be returned
* @param val the value to append
@ -353,6 +352,22 @@ public class Nd4j {
* @return the newly created array
*/
public static INDArray append(INDArray arr, int padAmount, double val, int axis) {
return appendImpl(arr, padAmount, val, axis, true);
}
/**
* @see #append(INDArray, int, double, int)
*/
public static INDArray prepend(INDArray arr, int padAmount, double val, int axis) {
return appendImpl(arr, padAmount, val, axis, false);
}
/**
* Append / Prepend shared implementation.
* @param appendFlag flag to determine Append / Prepend.
* @see #append(INDArray, int, double, int)
*/
private static INDArray appendImpl(INDArray arr, int padAmount, double val, int axis, boolean appendFlag){
if (padAmount == 0)
return arr;
long[] paShape = ArrayUtil.copy(arr.shape());
@ -360,29 +375,7 @@ public class Nd4j {
axis = axis + arr.shape().length;
paShape[axis] = padAmount;
INDArray concatArray = Nd4j.valueArrayOf(paShape, val, arr.dataType());
return Nd4j.concat(axis, arr, concatArray);
}
/**
* Append the given
* array with the specified value size
* along a particular axis
* @param arr the array to append to
* @param padAmount the pad amount of the array to be returned
* @param val the value to append
* @param axis the axis to append to
* @return the newly created array
*/
public static INDArray prepend(INDArray arr, int padAmount, double val, int axis) {
if (padAmount == 0)
return arr;
long[] paShape = ArrayUtil.copy(arr.shape());
if (axis < 0)
axis = axis + arr.shape().length;
paShape[axis] = padAmount;
INDArray concatArr = Nd4j.valueArrayOf(paShape, val, arr.dataType());
return Nd4j.concat(axis, concatArr, arr);
return appendFlag ? Nd4j.concat(axis, arr, concatArray) : Nd4j.concat(axis, concatArray, arr);
}
/**