Fixes, cleanup, enable now fixed tests, etc (#254)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-11 23:37:24 +10:00 committed by GitHub
parent 8a05ec2a97
commit 3e73e9b56e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 75 additions and 68 deletions

View File

@ -16,6 +16,10 @@
package org.nd4j.autodiff.validation; package org.nd4j.autodiff.validation;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath; import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -900,6 +904,7 @@ public class OpValidation {
TanhDerivative.class, TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
PowDerivative.class, PowDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
@ -911,6 +916,8 @@ public class OpValidation {
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
BiasAddGrad.class, BiasAddGrad.class,
ConcatBp.class, ConcatBp.class,
@ -976,7 +983,8 @@ public class OpValidation {
BarnesHutSymmetrize.class, BarnesHutSymmetrize.class,
SpTreeCell.class, SpTreeCell.class,
CbowRound.class, CbowRound.class,
SkipGramRound.class SkipGramRound.class,
HashCode.class
); );
return new HashSet<>(list); return new HashSet<>(list);
@ -1026,11 +1034,21 @@ public class OpValidation {
IMax.class, IMax.class,
IMin.class, IMin.class,
LastIndex.class, LastIndex.class,
ArgMax.class,
ArgMin.class,
//Exclude ops that output integer types only: //Exclude ops that output integer types only:
Shape.class, Shape.class,
ShapeN.class, ShapeN.class,
SizeAt.class, SizeAt.class,
BroadcastDynamicShape.class,
ReductionShape.class,
ShiftBits.class,
RShiftBits.class,
BitsHammingDistance.class,
CyclicShiftBits.class,
CyclicRShiftBits.class,
//Exclude Random ops //Exclude Random ops
RandomStandardNormal.class, RandomStandardNormal.class,

View File

@ -43,8 +43,8 @@ public class TensorArrayConcat extends BaseTensorOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArrayConcatV3"; return new String[]{"TensorArrayConcat", "TensorArrayConcatV2", "TensorArrayConcatV3"};
} }

View File

@ -43,8 +43,8 @@ public class TensorArrayGather extends BaseTensorOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArrayGatherV3"; return new String[]{"TensorArrayGather", "TensorArrayGatherV2", "TensorArrayGatherV3"};
} }

View File

@ -42,9 +42,10 @@ public class TensorArrayRead extends BaseTensorOp {
} }
public TensorArrayRead(){} public TensorArrayRead(){}
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArrayReadV3"; return new String[]{"TensorArrayRead", "TensorArrayReadV2", "TensorArrayReadV3"};
} }

View File

@ -36,8 +36,8 @@ public class TensorArrayScatter extends BaseTensorOp {
public TensorArrayScatter(){} public TensorArrayScatter(){}
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArrayScatterV3"; return new String[]{"TensorArrayScatter", "TensorArrayScatterV2", "TensorArrayScatterV3"};
} }
@Override @Override

View File

@ -31,8 +31,8 @@ import java.util.Map;
public class TensorArraySize extends BaseTensorOp { public class TensorArraySize extends BaseTensorOp {
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArraySizeV3"; return new String[]{"TensorArraySize", "TensorArraySizeV2", "TensorArraySizeV3"};
} }

View File

@ -36,8 +36,8 @@ public class TensorArraySplit extends BaseTensorOp {
public TensorArraySplit(){} public TensorArraySplit(){}
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArraySplitV3"; return new String[]{"TensorArraySplit", "TensorArraySplitV2", "TensorArraySplitV3"};
} }
@Override @Override

View File

@ -35,8 +35,8 @@ public class TensorArrayWrite extends BaseTensorOp {
public TensorArrayWrite(){} public TensorArrayWrite(){}
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "TensorArrayWriteV3"; return new String[]{"TensorArrayWrite", "TensorArrayWriteV2", "TensorArrayWriteV3"};
} }
@Override @Override

View File

@ -71,28 +71,6 @@ public class ClipByNorm extends DynamicCustomOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> grad) { public List<SDVariable> doDiff(List<SDVariable> grad) {
//dOut/dIn is ??? if clipped, 1 otherwise
/*
int origRank = Shape.rankFromShape(arg().getShape());
SDVariable l2norm = f().norm2(arg(), true, dimensions);
SDVariable isClippedBC = f().gte(l2norm, clipValue);
SDVariable notClippedBC = isClippedBC.rsub(1.0);
// SDVariable dnormdx = arg().div(broadcastableNorm);
// SDVariable sqNorm = f().square(broadcastableNorm);
// SDVariable dOutdInClipped = sqNorm.rdiv(-1).mul(dnormdx).mul(arg()) //-1/(norm2(x))^2 * x/norm2(x)
// .add(broadcastableNorm.rdiv(1.0))
// .mul(clipValue);
SDVariable dOutdInClipped = f().neg(f().square(arg()).div(f().cube(l2norm))) //-x^2/(norm2(x))^3
.add(l2norm.rdiv(1.0)) //+ 1/norm(x)
.mul(clipValue).mul(isClippedBC);
SDVariable ret = notClippedBC.add(dOutdInClipped).mul(grad.get(0));
return Arrays.asList(ret);
*/
return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable()); return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable());
} }

View File

@ -50,7 +50,7 @@ public class BitwiseAnd extends BaseDynamicTransformOp {
@Override @Override
public String opName() { public String opName() {
return "bitwise_and"; return "BitwiseAnd";
} }

View File

@ -61,7 +61,7 @@ public class BitwiseOr extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "bitwise_or"; return "BitwiseOr";
} }

View File

@ -61,7 +61,7 @@ public class BitwiseXor extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "bitwise_xor"; return "BitwiseXor";
} }

View File

@ -2472,4 +2472,28 @@ public class ShapeOpValidation extends BaseOpValidation {
.build(); .build();
} }
@Test
public void testBroadcastInt1() {
INDArray out = Nd4j.create(DataType.INT, 1);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
.addInputs(Nd4j.createFromArray(1), Nd4j.createFromArray(4))
.addOutputs(out)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.createFromArray(4), out);
}
@Test
public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
.addInputs(Nd4j.createFromArray(2, 2), Nd4j.createFromArray(1))
.addOutputs(out)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.createFromArray(2, 2), out);
}
} }

View File

@ -66,30 +66,23 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
private static final String MODEL_FILENAME = "frozen_model.pb"; private static final String MODEL_FILENAME = "frozen_model.pb";
public static final String[] IGNORE_REGEXES = new String[]{ public static final String[] IGNORE_REGEXES = new String[]{
//Still failing: 2019/07/01 - https://github.com/deeplearning4j/deeplearning4j/issues/6322 and https://github.com/eclipse/deeplearning4j/issues/7955
"broadcast_dynamic_shape/1_4",
"broadcast_dynamic_shape/2,2_1",
//Failing 2019/07/01 - Libnd4j Concat sizing issue - https://github.com/eclipse/deeplearning4j/issues/7963
"boolean_mask/.*",
//Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958 //Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958
//Still failing 2019/09/11
"slogdet/.*", "slogdet/.*",
//Failing 2019/07/01 - https://github.com/eclipse/deeplearning4j/issues/7965 //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
"bincount/.*", "bincount/.*",
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too //TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
"truncatemod/.*", "truncatemod/.*",
//Still failing as of 2019/07/02 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447 //Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME", "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
//2019/07/02 - No tensorflow op found for SparseTensorDenseAdd //2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
"confusion/.*", "confusion/.*",
//2019/07/02 - Couple of tests failing (InferenceSession issues) //2019/09/11 - Couple of tests failing (InferenceSession issues)
"rnn/bstack/d_.*", "rnn/bstack/d_.*",
//2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere //2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere
@ -103,13 +96,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
//2019/05/28 - JVM crash on ppc64le only - See issue 7657 //2019/05/28 - JVM crash on ppc64le only - See issue 7657
"g_11", "g_11",
//2019/06/21 - Not yet implemented: https://github.com/eclipse/deeplearning4j/issues/7913
"fake_quant/min_max_args_per_channel/.*",
//2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935
"fake_quant/min_max_vars/.*",
"fake_quant/min_max_args/.*",
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913 //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinomial/.*" "multinomial/.*"
}; };