Fixes, cleanup, enable now fixed tests, etc (#254)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
8a05ec2a97
commit
3e73e9b56e
|
@ -16,6 +16,10 @@
|
|||
|
||||
package org.nd4j.autodiff.validation;
|
||||
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -253,7 +257,7 @@ public class OpValidation {
|
|||
public static void checkDeserializedEquality(SameDiff original, ByteBuffer bbSerialized, TestCase tc) {
|
||||
SameDiff deserialized;
|
||||
try{
|
||||
deserialized = SameDiff.fromFlatBuffers(bbSerialized);
|
||||
deserialized = SameDiff.fromFlatBuffers(bbSerialized);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException("IOException deserializing from FlatBuffers", e);
|
||||
}
|
||||
|
@ -900,6 +904,7 @@ public class OpValidation {
|
|||
TanhDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
|
||||
PowDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
|
||||
|
@ -911,6 +916,8 @@ public class OpValidation {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
|
||||
|
||||
|
||||
BiasAddGrad.class,
|
||||
ConcatBp.class,
|
||||
|
@ -976,7 +983,8 @@ public class OpValidation {
|
|||
BarnesHutSymmetrize.class,
|
||||
SpTreeCell.class,
|
||||
CbowRound.class,
|
||||
SkipGramRound.class
|
||||
SkipGramRound.class,
|
||||
HashCode.class
|
||||
);
|
||||
|
||||
return new HashSet<>(list);
|
||||
|
@ -1026,11 +1034,21 @@ public class OpValidation {
|
|||
IMax.class,
|
||||
IMin.class,
|
||||
LastIndex.class,
|
||||
ArgMax.class,
|
||||
ArgMin.class,
|
||||
|
||||
//Exclude ops that output integer types only:
|
||||
Shape.class,
|
||||
ShapeN.class,
|
||||
SizeAt.class,
|
||||
BroadcastDynamicShape.class,
|
||||
ReductionShape.class,
|
||||
ShiftBits.class,
|
||||
RShiftBits.class,
|
||||
BitsHammingDistance.class,
|
||||
CyclicShiftBits.class,
|
||||
CyclicRShiftBits.class,
|
||||
|
||||
|
||||
//Exclude Random ops
|
||||
RandomStandardNormal.class,
|
||||
|
@ -1209,7 +1227,7 @@ public class OpValidation {
|
|||
"to_int64",
|
||||
"to_uint32",
|
||||
"to_uint64"
|
||||
);
|
||||
);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
|
|
@ -37,14 +37,14 @@ public class TensorArrayConcat extends BaseTensorOp {
|
|||
}
|
||||
|
||||
public TensorArrayConcat(){}
|
||||
@Override
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op name found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArrayConcatV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArrayConcat", "TensorArrayConcatV2", "TensorArrayConcatV3"};
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -37,14 +37,14 @@ public class TensorArrayGather extends BaseTensorOp {
|
|||
}
|
||||
|
||||
public TensorArrayGather(){}
|
||||
@Override
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op name found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArrayGatherV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArrayGather", "TensorArrayGatherV2", "TensorArrayGatherV3"};
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -42,9 +42,10 @@ public class TensorArrayRead extends BaseTensorOp {
|
|||
}
|
||||
|
||||
public TensorArrayRead(){}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArrayReadV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArrayRead", "TensorArrayReadV2", "TensorArrayReadV3"};
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -36,8 +36,8 @@ public class TensorArrayScatter extends BaseTensorOp {
|
|||
public TensorArrayScatter(){}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArrayScatterV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArrayScatter", "TensorArrayScatterV2", "TensorArrayScatterV3"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -31,8 +31,8 @@ import java.util.Map;
|
|||
|
||||
public class TensorArraySize extends BaseTensorOp {
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArraySizeV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArraySize", "TensorArraySizeV2", "TensorArraySizeV3"};
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -36,8 +36,8 @@ public class TensorArraySplit extends BaseTensorOp {
|
|||
public TensorArraySplit(){}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArraySplitV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArraySplit", "TensorArraySplitV2", "TensorArraySplitV3"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -35,8 +35,8 @@ public class TensorArrayWrite extends BaseTensorOp {
|
|||
|
||||
public TensorArrayWrite(){}
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "TensorArrayWriteV3";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"TensorArrayWrite", "TensorArrayWriteV2", "TensorArrayWriteV3"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -49,9 +49,9 @@ public class TensorArrayWrite extends BaseTensorOp {
|
|||
return Op.Type.CUSTOM;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataType){
|
||||
//Dummy float variable
|
||||
return Collections.singletonList(DataType.FLOAT);
|
||||
}
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataType){
|
||||
//Dummy float variable
|
||||
return Collections.singletonList(DataType.FLOAT);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,28 +71,6 @@ public class ClipByNorm extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
||||
//dOut/dIn is ??? if clipped, 1 otherwise
|
||||
/*
|
||||
int origRank = Shape.rankFromShape(arg().getShape());
|
||||
SDVariable l2norm = f().norm2(arg(), true, dimensions);
|
||||
SDVariable isClippedBC = f().gte(l2norm, clipValue);
|
||||
SDVariable notClippedBC = isClippedBC.rsub(1.0);
|
||||
|
||||
// SDVariable dnormdx = arg().div(broadcastableNorm);
|
||||
// SDVariable sqNorm = f().square(broadcastableNorm);
|
||||
// SDVariable dOutdInClipped = sqNorm.rdiv(-1).mul(dnormdx).mul(arg()) //-1/(norm2(x))^2 * x/norm2(x)
|
||||
// .add(broadcastableNorm.rdiv(1.0))
|
||||
// .mul(clipValue);
|
||||
|
||||
SDVariable dOutdInClipped = f().neg(f().square(arg()).div(f().cube(l2norm))) //-x^2/(norm2(x))^3
|
||||
.add(l2norm.rdiv(1.0)) //+ 1/norm(x)
|
||||
.mul(clipValue).mul(isClippedBC);
|
||||
|
||||
|
||||
SDVariable ret = notClippedBC.add(dOutdInClipped).mul(grad.get(0));
|
||||
return Arrays.asList(ret);
|
||||
*/
|
||||
|
||||
return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable());
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ public class BitwiseAnd extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_and";
|
||||
return "BitwiseAnd";
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class BitwiseOr extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_or";
|
||||
return "BitwiseOr";
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class BitwiseXor extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_xor";
|
||||
return "BitwiseXor";
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -2472,4 +2472,28 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcastInt1() {
|
||||
|
||||
INDArray out = Nd4j.create(DataType.INT, 1);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
||||
.addInputs(Nd4j.createFromArray(1), Nd4j.createFromArray(4))
|
||||
.addOutputs(out)
|
||||
.build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
assertEquals(Nd4j.createFromArray(4), out);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcastInt2(){
|
||||
INDArray out = Nd4j.create(DataType.INT, 2);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
||||
.addInputs(Nd4j.createFromArray(2, 2), Nd4j.createFromArray(1))
|
||||
.addOutputs(out)
|
||||
.build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
|
||||
assertEquals(Nd4j.createFromArray(2, 2), out);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,30 +66,23 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
private static final String MODEL_FILENAME = "frozen_model.pb";
|
||||
|
||||
public static final String[] IGNORE_REGEXES = new String[]{
|
||||
|
||||
//Still failing: 2019/07/01 - https://github.com/deeplearning4j/deeplearning4j/issues/6322 and https://github.com/eclipse/deeplearning4j/issues/7955
|
||||
"broadcast_dynamic_shape/1_4",
|
||||
"broadcast_dynamic_shape/2,2_1",
|
||||
|
||||
//Failing 2019/07/01 - Libnd4j Concat sizing issue - https://github.com/eclipse/deeplearning4j/issues/7963
|
||||
"boolean_mask/.*",
|
||||
|
||||
//Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958
|
||||
//Still failing 2019/09/11
|
||||
"slogdet/.*",
|
||||
|
||||
//Failing 2019/07/01 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||
"bincount/.*",
|
||||
|
||||
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
||||
"truncatemod/.*",
|
||||
|
||||
//Still failing as of 2019/07/02 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
||||
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
||||
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
|
||||
|
||||
//2019/07/02 - No tensorflow op found for SparseTensorDenseAdd
|
||||
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
||||
"confusion/.*",
|
||||
|
||||
//2019/07/02 - Couple of tests failing (InferenceSession issues)
|
||||
//2019/09/11 - Couple of tests failing (InferenceSession issues)
|
||||
"rnn/bstack/d_.*",
|
||||
|
||||
//2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere
|
||||
|
@ -103,13 +96,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
//2019/05/28 - JVM crash on ppc64le only - See issue 7657
|
||||
"g_11",
|
||||
|
||||
//2019/06/21 - Not yet implemented: https://github.com/eclipse/deeplearning4j/issues/7913
|
||||
"fake_quant/min_max_args_per_channel/.*",
|
||||
|
||||
//2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935
|
||||
"fake_quant/min_max_vars/.*",
|
||||
"fake_quant/min_max_args/.*",
|
||||
|
||||
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
||||
"multinomial/.*"
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue