refactor duplicate code from pad methods. (#86)
* refactor duplicate code from pad methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace switch with if. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
d4e7997134
commit
aa4af2c36d
|
@ -225,85 +225,64 @@ public class Nd4j {
|
||||||
* based on the specified mode
|
* based on the specified mode
|
||||||
*/
|
*/
|
||||||
public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) {
|
public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) {
|
||||||
switch (padMode) {
|
if (padMode == PadMode.CONSTANT) {
|
||||||
case CONSTANT:
|
if (padWidth.length < toPad.rank())
|
||||||
if (padWidth.length < toPad.rank())
|
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
|
||||||
|
|
||||||
List<int[]> sizes = new ArrayList<>();
|
List<int[]> sizes = new ArrayList<>();
|
||||||
for (int i = 0; i < toPad.rank(); i++) {
|
for (int i = 0; i < toPad.rank(); i++) {
|
||||||
sizes.add(padWidth[i]);
|
sizes.add(padWidth[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray ret = toPad;
|
|
||||||
for (int i = 0; i < toPad.rank(); i++) {
|
|
||||||
int[] pad = sizes.get(i);
|
|
||||||
double[] constant = constantValues.get(i);
|
|
||||||
int padBefore = pad[0];
|
|
||||||
int padAfter = pad[1];
|
|
||||||
if (constant.length < 2) {
|
|
||||||
double val = constant[0];
|
|
||||||
constant = new double[2];
|
|
||||||
constant[0] = val;
|
|
||||||
constant[1] = val;
|
|
||||||
}
|
|
||||||
|
|
||||||
double beforeVal = constant[0];
|
|
||||||
double afterVal = constant[1];
|
|
||||||
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
|
||||||
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
|
|
||||||
|
return padImpl(toPad, sizes, constantValues);
|
||||||
}
|
}
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
|
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
|
||||||
*/
|
*/
|
||||||
public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
|
public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
|
||||||
switch (padMode) {
|
if (padMode == PadMode.CONSTANT) {
|
||||||
case CONSTANT:
|
if (padWidth.length < toPad.rank())
|
||||||
if (padWidth.length < toPad.rank())
|
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
|
||||||
|
|
||||||
toPad = Nd4j.stripOnes(toPad);
|
toPad = Nd4j.stripOnes(toPad);
|
||||||
|
|
||||||
List<int[]> sizes = new ArrayList<>();
|
List<int[]> sizes = new ArrayList<>();
|
||||||
for (int i = 0; i < toPad.rank(); i++) {
|
for (int i = 0; i < toPad.rank(); i++) {
|
||||||
sizes.add(padWidth);
|
sizes.add(padWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray ret = toPad;
|
return padImpl(toPad, sizes, constantValues);
|
||||||
//TODO: Remove duplicate code.
|
|
||||||
for (int i = 0; i < toPad.rank(); i++) {
|
|
||||||
int[] pad = sizes.get(i);
|
|
||||||
double[] constant = constantValues.get(i);
|
|
||||||
int padBefore = pad[0];
|
|
||||||
int padAfter = pad[1];
|
|
||||||
if (constant.length < 2) {
|
|
||||||
double val = constant[0];
|
|
||||||
constant = new double[2];
|
|
||||||
constant[0] = val;
|
|
||||||
constant[1] = val;
|
|
||||||
}
|
|
||||||
|
|
||||||
double beforeVal = constant[0];
|
|
||||||
double afterVal = constant[1];
|
|
||||||
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
|
||||||
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
|
||||||
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
}
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
// common code for pad(INDArray, int[], List<double[]>, PadMode) and
|
||||||
|
// pad(INDArray, int[][], List<double[]>, PadMode)
|
||||||
|
private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){
|
||||||
|
|
||||||
|
INDArray ret = toPad;
|
||||||
|
for (int i = 0; i < toPad.rank(); i++) {
|
||||||
|
int[] pad = sizes.get(i);
|
||||||
|
double[] constant = constantValues.get(i);
|
||||||
|
int padBefore = pad[0];
|
||||||
|
int padAfter = pad[1];
|
||||||
|
if (constant.length < 2) {
|
||||||
|
double val = constant[0];
|
||||||
|
constant = new double[2];
|
||||||
|
constant[0] = val;
|
||||||
|
constant[1] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
double beforeVal = constant[0];
|
||||||
|
double afterVal = constant[1];
|
||||||
|
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
||||||
|
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
||||||
|
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Reference in New Issue