From aa4af2c36d5fbcc15e99250a1dc718b8c8e88a98 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Mon, 29 Jul 2019 21:20:16 +0900 Subject: [PATCH] refactor duplicate code from pad methods. (#86) * refactor duplicate code from pad methods. Signed-off-by: Robert Altena * replace switch with if. Signed-off-by: Robert Altena --- .../java/org/nd4j/linalg/factory/Nd4j.java | 111 +++++++----------- 1 file changed, 45 insertions(+), 66 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 82dd64337..b82a9e19e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -225,85 +225,64 @@ public class Nd4j { * based on the specified mode */ public static INDArray pad(INDArray toPad, int[][] padWidth, List constantValues, PadMode padMode) { - switch (padMode) { - case CONSTANT: - if (padWidth.length < toPad.rank()) - throw new IllegalArgumentException("Please specify a pad width for each dimension"); + if (padMode == PadMode.CONSTANT) { + if (padWidth.length < toPad.rank()) + throw new IllegalArgumentException("Please specify a pad width for each dimension"); - List sizes = new ArrayList<>(); - for (int i = 0; i < toPad.rank(); i++) { - sizes.add(padWidth[i]); - } - - INDArray ret = toPad; - for (int i = 0; i < toPad.rank(); i++) { - int[] pad = sizes.get(i); - double[] constant = constantValues.get(i); - int padBefore = pad[0]; - int padAfter = pad[1]; - if (constant.length < 2) { - double val = constant[0]; - constant = new double[2]; - constant[0] = val; - constant[1] = val; - } - - double beforeVal = constant[0]; - double afterVal = constant[1]; - ret = Nd4j.prepend(ret, padBefore, beforeVal, i); - ret = Nd4j.append(ret, padAfter, afterVal, i); - - } - - return ret; - - default: - throw new UnsupportedOperationException(); + List sizes = new ArrayList<>(); + for (int i = 0; i < toPad.rank(); i++) { + sizes.add(padWidth[i]); + } + return padImpl(toPad, sizes, constantValues); } + throw new UnsupportedOperationException(); } /** * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth. */ public static INDArray pad(INDArray toPad, int[] padWidth, List constantValues, PadMode padMode) { - switch (padMode) { - case CONSTANT: - if (padWidth.length < toPad.rank()) - throw new IllegalArgumentException("Please specify a pad width for each dimension"); + if (padMode == PadMode.CONSTANT) { + if (padWidth.length < toPad.rank()) + throw new IllegalArgumentException("Please specify a pad width for each dimension"); - toPad = Nd4j.stripOnes(toPad); + toPad = Nd4j.stripOnes(toPad); - List sizes = new ArrayList<>(); - for (int i = 0; i < toPad.rank(); i++) { - sizes.add(padWidth); - } + List sizes = new ArrayList<>(); + for (int i = 0; i < toPad.rank(); i++) { + sizes.add(padWidth); + } - INDArray ret = toPad; - //TODO: Remove duplicate code. - for (int i = 0; i < toPad.rank(); i++) { - int[] pad = sizes.get(i); - double[] constant = constantValues.get(i); - int padBefore = pad[0]; - int padAfter = pad[1]; - if (constant.length < 2) { - double val = constant[0]; - constant = new double[2]; - constant[0] = val; - constant[1] = val; - } - - double beforeVal = constant[0]; - double afterVal = constant[1]; - ret = Nd4j.prepend(ret, padBefore, beforeVal, i); - ret = Nd4j.append(ret, padAfter, afterVal, i); - - } - return ret; - - default: - throw new UnsupportedOperationException(); + return padImpl(toPad, sizes, constantValues); } + throw new UnsupportedOperationException(); + } + + // common code for pad(INDArray, int[], List, PadMode) and + // pad(INDArray, int[][], List, PadMode) + private static INDArray padImpl(INDArray toPad, List sizes, List constantValues){ + + INDArray ret = toPad; + for (int i = 0; i < toPad.rank(); i++) { + int[] pad = sizes.get(i); + double[] constant = constantValues.get(i); + int padBefore = pad[0]; + int padAfter = pad[1]; + if (constant.length < 2) { + double val = constant[0]; + constant = new double[2]; + constant[0] = val; + constant[1] = val; + } + + double beforeVal = constant[0]; + double afterVal = constant[1]; + ret = Nd4j.prepend(ret, padBefore, beforeVal, i); + ret = Nd4j.append(ret, padAfter, afterVal, i); + + } + return ret; } /**