Small tweaks (#119)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-09 23:08:00 +11:00 committed by GitHub
parent fc334ff47a
commit 0175ace4c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 10 deletions

View File

@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -80,6 +82,14 @@ public class ArraySavingListener extends BaseListener {
boolean eq = arr1.equalsWithEps(arr2, eps); boolean eq = arr1.equalsWithEps(arr2, eps);
if(eq){ if(eq){
System.out.println("Equals: " + varName.replaceAll("__", "/")); System.out.println("Equals: " + varName.replaceAll("__", "/"));
} else {
if(arr1.dataType() == DataType.BOOL){
INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
int count = xor.castTo(DataType.INT).sumNumber().intValue();
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count);
System.out.println("\t" + f.getAbsolutePath());
System.out.println("\t" + f2.getAbsolutePath());
xor.close();
} else { } else {
INDArray sub = arr1.sub(arr2); INDArray sub = arr1.sub(arr2);
INDArray diff = Nd4j.math.abs(sub); INDArray diff = Nd4j.math.abs(sub);
@ -88,7 +98,8 @@ public class ArraySavingListener extends BaseListener {
System.out.println("\t" + f.getAbsolutePath()); System.out.println("\t" + f.getAbsolutePath());
System.out.println("\t" + f2.getAbsolutePath()); System.out.println("\t" + f2.getAbsolutePath());
sub.close(); sub.close();
diff.close();; diff.close();
}
} }
arr1.close(); arr1.close();
arr2.close(); arr2.close();

View File

@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp {
} }
@Override @Override
public String[] tensorflowNames() { public String tensorflowName() {
return new String[]{"Div","RealDiv"}; return "Div";
} }