parent
fc334ff47a
commit
0175ace4c3
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j
autodiff/listeners/debugging
linalg/api/ops/impl/transforms/pairwise/arithmetic
|
@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -81,14 +83,23 @@ public class ArraySavingListener extends BaseListener {
|
||||||
if(eq){
|
if(eq){
|
||||||
System.out.println("Equals: " + varName.replaceAll("__", "/"));
|
System.out.println("Equals: " + varName.replaceAll("__", "/"));
|
||||||
} else {
|
} else {
|
||||||
INDArray sub = arr1.sub(arr2);
|
if(arr1.dataType() == DataType.BOOL){
|
||||||
INDArray diff = Nd4j.math.abs(sub);
|
INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
|
||||||
double maxDiff = diff.maxNumber().doubleValue();
|
int count = xor.castTo(DataType.INT).sumNumber().intValue();
|
||||||
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff);
|
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count);
|
||||||
System.out.println("\t" + f.getAbsolutePath());
|
System.out.println("\t" + f.getAbsolutePath());
|
||||||
System.out.println("\t" + f2.getAbsolutePath());
|
System.out.println("\t" + f2.getAbsolutePath());
|
||||||
sub.close();
|
xor.close();
|
||||||
diff.close();;
|
} else {
|
||||||
|
INDArray sub = arr1.sub(arr2);
|
||||||
|
INDArray diff = Nd4j.math.abs(sub);
|
||||||
|
double maxDiff = diff.maxNumber().doubleValue();
|
||||||
|
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff);
|
||||||
|
System.out.println("\t" + f.getAbsolutePath());
|
||||||
|
System.out.println("\t" + f2.getAbsolutePath());
|
||||||
|
sub.close();
|
||||||
|
diff.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
arr1.close();
|
arr1.close();
|
||||||
arr2.close();
|
arr2.close();
|
||||||
|
|
|
@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String tensorflowName() {
|
||||||
return new String[]{"Div","RealDiv"};
|
return "Div";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue