refactor duplicate code from pad methods. (#86)

* refactor duplicate code from pad methods.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* replace switch with if.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-07-29 21:20:16 +09:00 committed by AlexDBlack
parent d4e7997134
commit aa4af2c36d
1 changed files with 45 additions and 66 deletions

View File

@ -225,85 +225,64 @@ public class Nd4j {
* based on the specified mode * based on the specified mode
*/ */
public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) { public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) {
switch (padMode) { if (padMode == PadMode.CONSTANT) {
case CONSTANT: if (padWidth.length < toPad.rank())
if (padWidth.length < toPad.rank()) throw new IllegalArgumentException("Please specify a pad width for each dimension");
throw new IllegalArgumentException("Please specify a pad width for each dimension");
List<int[]> sizes = new ArrayList<>(); List<int[]> sizes = new ArrayList<>();
for (int i = 0; i < toPad.rank(); i++) { for (int i = 0; i < toPad.rank(); i++) {
sizes.add(padWidth[i]); sizes.add(padWidth[i]);
} }
INDArray ret = toPad;
for (int i = 0; i < toPad.rank(); i++) {
int[] pad = sizes.get(i);
double[] constant = constantValues.get(i);
int padBefore = pad[0];
int padAfter = pad[1];
if (constant.length < 2) {
double val = constant[0];
constant = new double[2];
constant[0] = val;
constant[1] = val;
}
double beforeVal = constant[0];
double afterVal = constant[1];
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
ret = Nd4j.append(ret, padAfter, afterVal, i);
}
return ret;
default:
throw new UnsupportedOperationException();
return padImpl(toPad, sizes, constantValues);
} }
throw new UnsupportedOperationException();
} }
/** /**
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth. * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
*/ */
public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) { public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
switch (padMode) { if (padMode == PadMode.CONSTANT) {
case CONSTANT: if (padWidth.length < toPad.rank())
if (padWidth.length < toPad.rank()) throw new IllegalArgumentException("Please specify a pad width for each dimension");
throw new IllegalArgumentException("Please specify a pad width for each dimension");
toPad = Nd4j.stripOnes(toPad); toPad = Nd4j.stripOnes(toPad);
List<int[]> sizes = new ArrayList<>(); List<int[]> sizes = new ArrayList<>();
for (int i = 0; i < toPad.rank(); i++) { for (int i = 0; i < toPad.rank(); i++) {
sizes.add(padWidth); sizes.add(padWidth);
} }
INDArray ret = toPad; return padImpl(toPad, sizes, constantValues);
//TODO: Remove duplicate code.
for (int i = 0; i < toPad.rank(); i++) {
int[] pad = sizes.get(i);
double[] constant = constantValues.get(i);
int padBefore = pad[0];
int padAfter = pad[1];
if (constant.length < 2) {
double val = constant[0];
constant = new double[2];
constant[0] = val;
constant[1] = val;
}
double beforeVal = constant[0];
double afterVal = constant[1];
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
ret = Nd4j.append(ret, padAfter, afterVal, i);
}
return ret;
default:
throw new UnsupportedOperationException();
} }
throw new UnsupportedOperationException();
}
// common code for pad(INDArray, int[], List<double[]>, PadMode) and
// pad(INDArray, int[][], List<double[]>, PadMode)
private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){
INDArray ret = toPad;
for (int i = 0; i < toPad.rank(); i++) {
int[] pad = sizes.get(i);
double[] constant = constantValues.get(i);
int padBefore = pad[0];
int padAfter = pad[1];
if (constant.length < 2) {
double val = constant[0];
constant = new double[2];
constant[0] = val;
constant[1] = val;
}
double beforeVal = constant[0];
double afterVal = constant[1];
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
ret = Nd4j.append(ret, padAfter, afterVal, i);
}
return ret;
} }
/** /**