refactor duplicate code from pad methods. (#86)
* refactor duplicate code from pad methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace switch with if. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
d4e7997134
commit
aa4af2c36d
|
@ -225,85 +225,64 @@ public class Nd4j {
|
|||
* based on the specified mode
|
||||
*/
|
||||
public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) {
|
||||
switch (padMode) {
|
||||
case CONSTANT:
|
||||
if (padWidth.length < toPad.rank())
|
||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||
if (padMode == PadMode.CONSTANT) {
|
||||
if (padWidth.length < toPad.rank())
|
||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||
|
||||
List<int[]> sizes = new ArrayList<>();
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
sizes.add(padWidth[i]);
|
||||
}
|
||||
|
||||
INDArray ret = toPad;
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
int[] pad = sizes.get(i);
|
||||
double[] constant = constantValues.get(i);
|
||||
int padBefore = pad[0];
|
||||
int padAfter = pad[1];
|
||||
if (constant.length < 2) {
|
||||
double val = constant[0];
|
||||
constant = new double[2];
|
||||
constant[0] = val;
|
||||
constant[1] = val;
|
||||
}
|
||||
|
||||
double beforeVal = constant[0];
|
||||
double afterVal = constant[1];
|
||||
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
||||
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
||||
|
||||
}
|
||||
|
||||
return ret;
|
||||
|
||||
default:
|
||||
throw new UnsupportedOperationException();
|
||||
List<int[]> sizes = new ArrayList<>();
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
sizes.add(padWidth[i]);
|
||||
}
|
||||
|
||||
return padImpl(toPad, sizes, constantValues);
|
||||
}
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
|
||||
*/
|
||||
public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
|
||||
switch (padMode) {
|
||||
case CONSTANT:
|
||||
if (padWidth.length < toPad.rank())
|
||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||
if (padMode == PadMode.CONSTANT) {
|
||||
if (padWidth.length < toPad.rank())
|
||||
throw new IllegalArgumentException("Please specify a pad width for each dimension");
|
||||
|
||||
toPad = Nd4j.stripOnes(toPad);
|
||||
toPad = Nd4j.stripOnes(toPad);
|
||||
|
||||
List<int[]> sizes = new ArrayList<>();
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
sizes.add(padWidth);
|
||||
}
|
||||
List<int[]> sizes = new ArrayList<>();
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
sizes.add(padWidth);
|
||||
}
|
||||
|
||||
INDArray ret = toPad;
|
||||
//TODO: Remove duplicate code.
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
int[] pad = sizes.get(i);
|
||||
double[] constant = constantValues.get(i);
|
||||
int padBefore = pad[0];
|
||||
int padAfter = pad[1];
|
||||
if (constant.length < 2) {
|
||||
double val = constant[0];
|
||||
constant = new double[2];
|
||||
constant[0] = val;
|
||||
constant[1] = val;
|
||||
}
|
||||
|
||||
double beforeVal = constant[0];
|
||||
double afterVal = constant[1];
|
||||
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
||||
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
||||
|
||||
}
|
||||
return ret;
|
||||
|
||||
default:
|
||||
throw new UnsupportedOperationException();
|
||||
return padImpl(toPad, sizes, constantValues);
|
||||
}
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
// common code for pad(INDArray, int[], List<double[]>, PadMode) and
|
||||
// pad(INDArray, int[][], List<double[]>, PadMode)
|
||||
private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){
|
||||
|
||||
INDArray ret = toPad;
|
||||
for (int i = 0; i < toPad.rank(); i++) {
|
||||
int[] pad = sizes.get(i);
|
||||
double[] constant = constantValues.get(i);
|
||||
int padBefore = pad[0];
|
||||
int padAfter = pad[1];
|
||||
if (constant.length < 2) {
|
||||
double val = constant[0];
|
||||
constant = new double[2];
|
||||
constant[0] = val;
|
||||
constant[1] = val;
|
||||
}
|
||||
|
||||
double beforeVal = constant[0];
|
||||
double afterVal = constant[1];
|
||||
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
|
||||
ret = Nd4j.append(ret, padAfter, afterVal, i);
|
||||
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue