Small tweaks (#119)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-09 23:08:00 +11:00 committed by GitHub
parent fc334ff47a
commit 0175ace4c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 10 deletions

View File

@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
@ -80,6 +82,14 @@ public class ArraySavingListener extends BaseListener {
boolean eq = arr1.equalsWithEps(arr2, eps);
if(eq){
System.out.println("Equals: " + varName.replaceAll("__", "/"));
} else {
if(arr1.dataType() == DataType.BOOL){
INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
int count = xor.castTo(DataType.INT).sumNumber().intValue();
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count);
System.out.println("\t" + f.getAbsolutePath());
System.out.println("\t" + f2.getAbsolutePath());
xor.close();
} else {
INDArray sub = arr1.sub(arr2);
INDArray diff = Nd4j.math.abs(sub);
@ -88,7 +98,8 @@ public class ArraySavingListener extends BaseListener {
System.out.println("\t" + f.getAbsolutePath());
System.out.println("\t" + f2.getAbsolutePath());
sub.close();
diff.close();;
diff.close();
}
}
arr1.close();
arr2.close();

View File

@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp {
}
@Override
public String[] tensorflowNames() {
return new String[]{"Div","RealDiv"};
public String tensorflowName() {
return "Div";
}