diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java new file mode 100644 index 000000000..9137fc831 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -0,0 +1,107 @@ +package org.nd4j.autodiff.listeners.debugging; + +import lombok.NonNull; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ArraySavingListener extends BaseListener { + + protected final File dir; + protected int count = 0; + + public ArraySavingListener(@NonNull File dir){ + + if(!dir.exists()){ + dir.mkdir(); + } + + if(dir.listFiles() != null && dir.listFiles().length > 0){ + throw new IllegalStateException("Directory is not empty: " + dir.getAbsolutePath()); + } + + this.dir = dir; + } + + @Override + public boolean isActive(Operation operation) { + return true; + } + + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + List outNames = op.getOutputsOfOp(); + for(int i=0; i m1 = toMap(files1); + Map m2 = toMap(files2); + + for(File f : files1){ + String name = f.getName(); + String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin" + File f2 = m2.get(varName); + + INDArray arr1 = Nd4j.readBinary(f); + INDArray arr2 = Nd4j.readBinary(f2); + + //TODO String arrays won't work here! + boolean eq = arr1.equalsWithEps(arr2, eps); + if(eq){ + System.out.println("Equals: " + varName.replaceAll("__", "/")); + } else { + INDArray sub = arr1.sub(arr2); + INDArray diff = Nd4j.math.abs(sub); + double maxDiff = diff.maxNumber().doubleValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + sub.close(); + diff.close();; + } + arr1.close(); + arr2.close(); + } + } + + private static Map toMap(File[] files){ + Map ret = new HashMap<>(); + for(File f : files) { + String name = f.getName(); + String varName = name.substring(name.indexOf('_') + 1, name.length() - 4); //Strip "x_" and ".bin" + ret.put(varName, f); + } + return ret; + } +}