parent
fc334ff47a
commit
0175ace4c3
|
@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -80,6 +82,14 @@ public class ArraySavingListener extends BaseListener {
|
|||
boolean eq = arr1.equalsWithEps(arr2, eps);
|
||||
if(eq){
|
||||
System.out.println("Equals: " + varName.replaceAll("__", "/"));
|
||||
} else {
|
||||
if(arr1.dataType() == DataType.BOOL){
|
||||
INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
|
||||
int count = xor.castTo(DataType.INT).sumNumber().intValue();
|
||||
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count);
|
||||
System.out.println("\t" + f.getAbsolutePath());
|
||||
System.out.println("\t" + f2.getAbsolutePath());
|
||||
xor.close();
|
||||
} else {
|
||||
INDArray sub = arr1.sub(arr2);
|
||||
INDArray diff = Nd4j.math.abs(sub);
|
||||
|
@ -88,7 +98,8 @@ public class ArraySavingListener extends BaseListener {
|
|||
System.out.println("\t" + f.getAbsolutePath());
|
||||
System.out.println("\t" + f2.getAbsolutePath());
|
||||
sub.close();
|
||||
diff.close();;
|
||||
diff.close();
|
||||
}
|
||||
}
|
||||
arr1.close();
|
||||
arr2.close();
|
||||
|
|
|
@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"Div","RealDiv"};
|
||||
public String tensorflowName() {
|
||||
return "Div";
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue