refactor duplicate code from pad methods. (#86)
* refactor duplicate code from pad methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace switch with if. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
This commit is contained in:
		
							parent
							
								
									d4e7997134
								
							
						
					
					
						commit
						aa4af2c36d
					
				| @ -225,85 +225,64 @@ public class Nd4j { | ||||
|      * based on the specified mode | ||||
|      */ | ||||
|     public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) { | ||||
|         switch (padMode) { | ||||
|             case CONSTANT: | ||||
|                 if (padWidth.length < toPad.rank()) | ||||
|                     throw new IllegalArgumentException("Please specify a pad width for each dimension"); | ||||
|         if (padMode == PadMode.CONSTANT) { | ||||
|             if (padWidth.length < toPad.rank()) | ||||
|                 throw new IllegalArgumentException("Please specify a pad width for each dimension"); | ||||
| 
 | ||||
|                 List<int[]> sizes = new ArrayList<>(); | ||||
|                 for (int i = 0; i < toPad.rank(); i++) { | ||||
|                     sizes.add(padWidth[i]); | ||||
|                 } | ||||
| 
 | ||||
|                 INDArray ret = toPad; | ||||
|                 for (int i = 0; i < toPad.rank(); i++) { | ||||
|                     int[] pad = sizes.get(i); | ||||
|                     double[] constant = constantValues.get(i); | ||||
|                     int padBefore = pad[0]; | ||||
|                     int padAfter = pad[1]; | ||||
|                     if (constant.length < 2) { | ||||
|                         double val = constant[0]; | ||||
|                         constant = new double[2]; | ||||
|                         constant[0] = val; | ||||
|                         constant[1] = val; | ||||
|                     } | ||||
| 
 | ||||
|                     double beforeVal = constant[0]; | ||||
|                     double afterVal = constant[1]; | ||||
|                     ret = Nd4j.prepend(ret, padBefore, beforeVal, i); | ||||
|                     ret = Nd4j.append(ret, padAfter, afterVal, i); | ||||
| 
 | ||||
|                 } | ||||
| 
 | ||||
|                 return ret; | ||||
| 
 | ||||
|             default: | ||||
|                 throw new UnsupportedOperationException(); | ||||
|             List<int[]> sizes = new ArrayList<>(); | ||||
|             for (int i = 0; i < toPad.rank(); i++) { | ||||
|                 sizes.add(padWidth[i]); | ||||
|             } | ||||
| 
 | ||||
|             return padImpl(toPad, sizes, constantValues); | ||||
|         } | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth. | ||||
|      */ | ||||
|     public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) { | ||||
|         switch (padMode) { | ||||
|             case CONSTANT: | ||||
|                 if (padWidth.length < toPad.rank()) | ||||
|                     throw new IllegalArgumentException("Please specify a pad width for each dimension"); | ||||
|         if (padMode == PadMode.CONSTANT) { | ||||
|             if (padWidth.length < toPad.rank()) | ||||
|                 throw new IllegalArgumentException("Please specify a pad width for each dimension"); | ||||
| 
 | ||||
|                 toPad = Nd4j.stripOnes(toPad); | ||||
|             toPad = Nd4j.stripOnes(toPad); | ||||
| 
 | ||||
|                 List<int[]> sizes = new ArrayList<>(); | ||||
|                 for (int i = 0; i < toPad.rank(); i++) { | ||||
|                     sizes.add(padWidth); | ||||
|                 } | ||||
|             List<int[]> sizes = new ArrayList<>(); | ||||
|             for (int i = 0; i < toPad.rank(); i++) { | ||||
|                 sizes.add(padWidth); | ||||
|             } | ||||
| 
 | ||||
|                 INDArray ret = toPad; | ||||
|                 //TODO: Remove duplicate code. | ||||
|                 for (int i = 0; i < toPad.rank(); i++) { | ||||
|                     int[] pad = sizes.get(i); | ||||
|                     double[] constant = constantValues.get(i); | ||||
|                     int padBefore = pad[0]; | ||||
|                     int padAfter = pad[1]; | ||||
|                     if (constant.length < 2) { | ||||
|                         double val = constant[0]; | ||||
|                         constant = new double[2]; | ||||
|                         constant[0] = val; | ||||
|                         constant[1] = val; | ||||
|                     } | ||||
| 
 | ||||
|                     double beforeVal = constant[0]; | ||||
|                     double afterVal = constant[1]; | ||||
|                     ret = Nd4j.prepend(ret, padBefore, beforeVal, i); | ||||
|                     ret = Nd4j.append(ret, padAfter, afterVal, i); | ||||
| 
 | ||||
|                 } | ||||
|                 return ret; | ||||
| 
 | ||||
|             default: | ||||
|                 throw new UnsupportedOperationException(); | ||||
|             return padImpl(toPad, sizes, constantValues); | ||||
|         } | ||||
|         throw new UnsupportedOperationException(); | ||||
|     } | ||||
| 
 | ||||
|     // common code for pad(INDArray, int[],   List<double[]>, PadMode) and | ||||
|     //                 pad(INDArray, int[][], List<double[]>, PadMode) | ||||
|     private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){ | ||||
| 
 | ||||
|         INDArray ret = toPad; | ||||
|         for (int i = 0; i < toPad.rank(); i++) { | ||||
|             int[] pad = sizes.get(i); | ||||
|             double[] constant = constantValues.get(i); | ||||
|             int padBefore = pad[0]; | ||||
|             int padAfter = pad[1]; | ||||
|             if (constant.length < 2) { | ||||
|                 double val = constant[0]; | ||||
|                 constant = new double[2]; | ||||
|                 constant[0] = val; | ||||
|                 constant[1] = val; | ||||
|             } | ||||
| 
 | ||||
|             double beforeVal = constant[0]; | ||||
|             double afterVal = constant[1]; | ||||
|             ret = Nd4j.prepend(ret, padBefore, beforeVal, i); | ||||
|             ret = Nd4j.append(ret, padAfter, afterVal, i); | ||||
| 
 | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user