Add ArraySavingListener for debugging (#114)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									70e08c3a6c
								
							
						
					
					
						commit
						b66154a9d4
					
				| @ -0,0 +1,107 @@ | ||||
| package org.nd4j.autodiff.listeners.debugging; | ||||
| 
 | ||||
| import lombok.NonNull; | ||||
| import org.nd4j.autodiff.listeners.At; | ||||
| import org.nd4j.autodiff.listeners.BaseListener; | ||||
| import org.nd4j.autodiff.listeners.Operation; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.autodiff.samediff.internal.SameDiffOp; | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.api.MultiDataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| public class ArraySavingListener extends BaseListener { | ||||
| 
 | ||||
|     protected final File dir; | ||||
|     protected int count = 0; | ||||
| 
 | ||||
|     public ArraySavingListener(@NonNull File dir){ | ||||
| 
 | ||||
|         if(!dir.exists()){ | ||||
|             dir.mkdir(); | ||||
|         } | ||||
| 
 | ||||
|         if(dir.listFiles() != null && dir.listFiles().length > 0){ | ||||
|             throw new IllegalStateException("Directory is not empty: " + dir.getAbsolutePath()); | ||||
|         } | ||||
| 
 | ||||
|         this.dir = dir; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean isActive(Operation operation) { | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { | ||||
|         List<String> outNames = op.getOutputsOfOp(); | ||||
|         for(int i=0; i<outputs.length; i++ ){ | ||||
|             String filename = (count++) + "_" + outNames.get(i).replaceAll("/", "__") + ".bin"; | ||||
|             File outFile = new File(dir, filename); | ||||
| 
 | ||||
|             INDArray arr = outputs[i]; | ||||
|             try { | ||||
|                 Nd4j.saveBinary(arr, outFile); | ||||
|                 System.out.println(outFile.getAbsolutePath()); | ||||
|             } catch (IOException e){ | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public static void compare(File dir1, File dir2, double eps) throws Exception { | ||||
|         File[] files1 = dir1.listFiles(); | ||||
|         File[] files2 = dir2.listFiles(); | ||||
|         Preconditions.checkNotNull(files1, "No files in directory 1: %s", dir1); | ||||
|         Preconditions.checkNotNull(files2, "No files in directory 2: %s", dir2); | ||||
|         Preconditions.checkState(files1.length == files2.length, "Different number of files: %s vs %s", files1.length, files2.length); | ||||
| 
 | ||||
|         Map<String,File> m1 = toMap(files1); | ||||
|         Map<String,File> m2 = toMap(files2); | ||||
| 
 | ||||
|         for(File f : files1){ | ||||
|             String name = f.getName(); | ||||
|             String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin" | ||||
|             File f2 = m2.get(varName); | ||||
| 
 | ||||
|             INDArray arr1 = Nd4j.readBinary(f); | ||||
|             INDArray arr2 = Nd4j.readBinary(f2); | ||||
| 
 | ||||
|             //TODO String arrays won't work here! | ||||
|             boolean eq = arr1.equalsWithEps(arr2, eps); | ||||
|             if(eq){ | ||||
|                 System.out.println("Equals: " + varName.replaceAll("__", "/")); | ||||
|             } else { | ||||
|                 INDArray sub = arr1.sub(arr2); | ||||
|                 INDArray diff = Nd4j.math.abs(sub); | ||||
|                 double maxDiff = diff.maxNumber().doubleValue(); | ||||
|                 System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); | ||||
|                 System.out.println("\t" + f.getAbsolutePath()); | ||||
|                 System.out.println("\t" + f2.getAbsolutePath()); | ||||
|                 sub.close(); | ||||
|                 diff.close();; | ||||
|             } | ||||
|             arr1.close(); | ||||
|             arr2.close(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private static Map<String,File> toMap(File[] files){ | ||||
|         Map<String,File> ret = new HashMap<>(); | ||||
|         for(File f : files) { | ||||
|             String  name = f.getName(); | ||||
|             String varName = name.substring(name.indexOf('_') + 1, name.length() - 4); //Strip "x_" and ".bin" | ||||
|             ret.put(varName, f); | ||||
|         } | ||||
|         return ret; | ||||
|     } | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user