Fixes, cleanup, enable now fixed tests, etc (#254)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

parent 8a05ec2a97
commit 3e73e9b56e
@@ -16,6 +16,10 @@
 package org.nd4j.autodiff.validation;

+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
+import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
+import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
 import org.nd4j.shade.guava.collect.ImmutableSet;
 import org.nd4j.shade.guava.reflect.ClassPath;
 import lombok.extern.slf4j.Slf4j;

@@ -253,7 +257,7 @@ public class OpValidation {
     public static void checkDeserializedEquality(SameDiff original, ByteBuffer bbSerialized, TestCase tc) {
         SameDiff deserialized;
         try{
             deserialized = SameDiff.fromFlatBuffers(bbSerialized);
         } catch (IOException e){
             throw new RuntimeException("IOException deserializing from FlatBuffers", e);
         }

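Note: the helper above is the deserialization half of the serialization-equality check. A minimal round-trip sketch for orientation; fromFlatBuffers(ByteBuffer) is taken directly from the hunk, while asFlatBuffers() as the serializing counterpart is an assumption about the SameDiff API of this era, not confirmed by this diff:

// Hedged sketch of a FlatBuffers round-trip.
import java.io.IOException;
import java.nio.ByteBuffer;
import org.nd4j.autodiff.samediff.SameDiff;

class FlatBuffersRoundTrip {
    public static void main(String[] args) throws IOException {
        SameDiff sd = SameDiff.create();
        sd.var("x", 2, 3);                                  // one variable so the graph is non-empty
        ByteBuffer bb = sd.asFlatBuffers();                 // ASSUMED serializer counterpart
        SameDiff restored = SameDiff.fromFlatBuffers(bb);   // call shown in the hunk above
        System.out.println(restored.variableMap().keySet());
    }
}
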
@@ -900,6 +904,7 @@ public class OpValidation {
             TanhDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
             PowDerivative.class,
+            org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,

@@ -911,6 +916,8 @@ public class OpValidation {
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
+
             BiasAddGrad.class,
             ConcatBp.class,

@@ -976,7 +983,8 @@ public class OpValidation {
             BarnesHutSymmetrize.class,
             SpTreeCell.class,
             CbowRound.class,
-            SkipGramRound.class
+            SkipGramRound.class,
+            HashCode.class
     );

     return new HashSet<>(list);

@@ -1026,11 +1034,21 @@ public class OpValidation {
             IMax.class,
             IMin.class,
             LastIndex.class,
+            ArgMax.class,
+            ArgMin.class,

             //Exclude ops that output integer types only:
             Shape.class,
             ShapeN.class,
             SizeAt.class,
+            BroadcastDynamicShape.class,
+            ReductionShape.class,
+            ShiftBits.class,
+            RShiftBits.class,
+            BitsHammingDistance.class,
+            CyclicShiftBits.class,
+            CyclicRShiftBits.class,
+

             //Exclude Random ops
             RandomStandardNormal.class,

@@ -1209,7 +1227,7 @@ public class OpValidation {
             "to_int64",
             "to_uint32",
             "to_uint64"
     );

     return out;
 }

@@ -37,14 +37,14 @@ public class TensorArrayConcat extends BaseTensorOp {
     }

     public TensorArrayConcat(){}
     @Override
     public String onnxName() {
         throw new NoOpNameFoundException("No onnx op name found for " + opName());
     }

     @Override
-    public String tensorflowName() {
-        return "TensorArrayConcatV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArrayConcat", "TensorArrayConcatV2", "TensorArrayConcatV3"};
     }

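Note on the pattern: every TensorArray op below gets the same widening from a single tensorflowName() to a tensorflowNames() array, so graphs frozen by older TensorFlow releases (which emit the V1/V2 op names) import onto the same ND4J op as V3 graphs. A minimal sketch of the idea, assuming a registry keyed by each declared alias; this is not the actual SameDiff import-mapper API:

// Sketch only (assumed registry, not the real DifferentialFunction machinery):
// register one ND4J op class under every TensorFlow alias it declares.
import java.util.HashMap;
import java.util.Map;

class TfAliasRegistry {
    private final Map<String, Class<?>> byTfName = new HashMap<>();

    void register(Class<?> opClass, String... tensorflowNames) {
        for (String name : tensorflowNames) {
            byTfName.put(name, opClass);        // each alias resolves to the same op
        }
    }

    Class<?> lookup(String graphDefOpType) {
        return byTfName.get(graphDefOpType);    // e.g. "TensorArrayConcatV2" -> TensorArrayConcat.class
    }
}
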
@@ -37,14 +37,14 @@ public class TensorArrayGather extends BaseTensorOp {
     }

     public TensorArrayGather(){}
     @Override
     public String onnxName() {
         throw new NoOpNameFoundException("No onnx op name found for " + opName());
     }

     @Override
-    public String tensorflowName() {
-        return "TensorArrayGatherV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArrayGather", "TensorArrayGatherV2", "TensorArrayGatherV3"};
     }

@@ -42,9 +42,10 @@ public class TensorArrayRead extends BaseTensorOp {
     }

     public TensorArrayRead(){}

     @Override
-    public String tensorflowName() {
-        return "TensorArrayReadV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArrayRead", "TensorArrayReadV2", "TensorArrayReadV3"};
     }

@@ -36,8 +36,8 @@ public class TensorArrayScatter extends BaseTensorOp {
     public TensorArrayScatter(){}

     @Override
-    public String tensorflowName() {
-        return "TensorArrayScatterV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArrayScatter", "TensorArrayScatterV2", "TensorArrayScatterV3"};
     }

     @Override

@@ -31,8 +31,8 @@ import java.util.Map;

 public class TensorArraySize extends BaseTensorOp {
     @Override
-    public String tensorflowName() {
-        return "TensorArraySizeV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArraySize", "TensorArraySizeV2", "TensorArraySizeV3"};
     }

@@ -36,8 +36,8 @@ public class TensorArraySplit extends BaseTensorOp {
     public TensorArraySplit(){}

     @Override
-    public String tensorflowName() {
-        return "TensorArraySplitV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArraySplit", "TensorArraySplitV2", "TensorArraySplitV3"};
     }

     @Override

@@ -35,8 +35,8 @@ public class TensorArrayWrite extends BaseTensorOp {

     public TensorArrayWrite(){}
     @Override
-    public String tensorflowName() {
-        return "TensorArrayWriteV3";
+    public String[] tensorflowNames() {
+        return new String[]{"TensorArrayWrite", "TensorArrayWriteV2", "TensorArrayWriteV3"};
     }

     @Override

@@ -49,9 +49,9 @@ public class TensorArrayWrite extends BaseTensorOp {
         return Op.Type.CUSTOM;
     }

     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataType){
         //Dummy float variable
         return Collections.singletonList(DataType.FLOAT);
     }
 }

@@ -71,28 +71,6 @@ public class ClipByNorm extends DynamicCustomOp {

     @Override
     public List<SDVariable> doDiff(List<SDVariable> grad) {
-        //dOut/dIn is ??? if clipped, 1 otherwise
-        /*
-        int origRank = Shape.rankFromShape(arg().getShape());
-        SDVariable l2norm = f().norm2(arg(), true, dimensions);
-        SDVariable isClippedBC = f().gte(l2norm, clipValue);
-        SDVariable notClippedBC = isClippedBC.rsub(1.0);
-
-//        SDVariable dnormdx = arg().div(broadcastableNorm);
-//        SDVariable sqNorm = f().square(broadcastableNorm);
-//        SDVariable dOutdInClipped = sqNorm.rdiv(-1).mul(dnormdx).mul(arg()) //-1/(norm2(x))^2 * x/norm2(x)
-//                .add(broadcastableNorm.rdiv(1.0))
-//                .mul(clipValue);
-
-        SDVariable dOutdInClipped = f().neg(f().square(arg()).div(f().cube(l2norm))) //-x^2/(norm2(x))^3
-                .add(l2norm.rdiv(1.0)) //+ 1/norm(x)
-                .mul(clipValue).mul(isClippedBC);
-
-
-        SDVariable ret = notClippedBC.add(dOutdInClipped).mul(grad.get(0));
-        return Arrays.asList(ret);
-        */
-
         return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable());
     }

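For reference, the gradient the deleted comment block was working toward, reconstructed from the removed inline comments (diagonal term only): with y = x when the norm is under the clip value c, and y = c * x / ||x||_2 otherwise,

\frac{\partial y_i}{\partial x_i} =
\begin{cases}
  1, & \lVert x \rVert_2 < c \\
  c \left( \dfrac{1}{\lVert x \rVert_2} - \dfrac{x_i^2}{\lVert x \rVert_2^{3}} \right), & \lVert x \rVert_2 \ge c
\end{cases}

The full Jacobian also carries off-diagonal terms $-c\, x_i x_j / \lVert x \rVert_2^{3}$ in the clipped case, which an element-wise formula cannot express; that is presumably one reason the hand-rolled derivation stayed commented out and doDiff now delegates to the dedicated ClipByNormBp op.
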
@@ -50,7 +50,7 @@ public class BitwiseAnd extends BaseDynamicTransformOp {

     @Override
     public String opName() {
-        return "bitwise_and";
+        return "BitwiseAnd";
     }

@@ -61,7 +61,7 @@ public class BitwiseOr extends BaseDynamicTransformOp {

     @Override
     public String tensorflowName() {
-        return "bitwise_or";
+        return "BitwiseOr";
     }

@@ -61,7 +61,7 @@ public class BitwiseXor extends BaseDynamicTransformOp {

     @Override
     public String tensorflowName() {
-        return "bitwise_xor";
+        return "BitwiseXor";
     }

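The three renames above move from the Python endpoint spelling ("bitwise_and", "bitwise_or", "bitwise_xor") to the CamelCase op type that actually appears in a serialized TensorFlow GraphDef ("BitwiseAnd", "BitwiseOr", "BitwiseXor"). A toy lookup illustrating why the lowercase names could never match during import (hypothetical handler table, not the real importer):

import java.util.HashMap;
import java.util.Map;

class OpTypeLookupDemo {
    public static void main(String[] args) {
        // Hypothetical importer table keyed by GraphDef op type (CamelCase).
        Map<String, String> handlers = new HashMap<>();
        handlers.put("BitwiseAnd", "nd4j BitwiseAnd handler");

        System.out.println(handlers.get("BitwiseAnd"));   // matches a frozen-graph node
        System.out.println(handlers.get("bitwise_and"));  // null -> "No tensorflow op found"
    }
}
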
@@ -2472,4 +2472,28 @@ public class ShapeOpValidation extends BaseOpValidation {
                 .build();
     }

+    @Test
+    public void testBroadcastInt1() {
+
+        INDArray out = Nd4j.create(DataType.INT, 1);
+        DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
+                .addInputs(Nd4j.createFromArray(1), Nd4j.createFromArray(4))
+                .addOutputs(out)
+                .build();
+        Nd4j.getExecutioner().exec(op);
+        assertEquals(Nd4j.createFromArray(4), out);
+
+    }
+
+    @Test
+    public void testBroadcastInt2(){
+        INDArray out = Nd4j.create(DataType.INT, 2);
+        DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
+                .addInputs(Nd4j.createFromArray(2, 2), Nd4j.createFromArray(1))
+                .addOutputs(out)
+                .build();
+        Nd4j.getExecutioner().exec(op);
+
+        assertEquals(Nd4j.createFromArray(2, 2), out);
+    }
 }

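The two new tests pin down what broadcast_dynamic_shape computes: NumPy-style broadcasting applied to shape vectors, so [1] with [4] yields [4], and [2, 2] with [1] yields [2, 2]. A plain-Java sketch of that rule, for illustration only (not the libnd4j implementation):

// Align shape vectors right-to-left; extents must match or one must be 1.
import java.util.Arrays;

class BroadcastShapeSketch {
    static int[] broadcast(int[] a, int[] b) {
        int n = Math.max(a.length, b.length);
        int[] out = new int[n];
        for (int i = 0; i < n; i++) {   // i counts from the trailing dimension
            int da = i < a.length ? a[a.length - 1 - i] : 1;
            int db = i < b.length ? b[b.length - 1 - i] : 1;
            if (da != db && da != 1 && db != 1)
                throw new IllegalArgumentException("Incompatible shapes at dim " + i);
            out[n - 1 - i] = Math.max(da, db);  // the non-1 (or larger) extent wins
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(broadcast(new int[]{1}, new int[]{4})));    // [4]
        System.out.println(Arrays.toString(broadcast(new int[]{2, 2}, new int[]{1}))); // [2, 2]
    }
}
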
@@ -66,30 +66,23 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
     private static final String MODEL_FILENAME = "frozen_model.pb";

     public static final String[] IGNORE_REGEXES = new String[]{

-            //Still failing: 2019/07/01 - https://github.com/deeplearning4j/deeplearning4j/issues/6322 and https://github.com/eclipse/deeplearning4j/issues/7955
-            "broadcast_dynamic_shape/1_4",
-            "broadcast_dynamic_shape/2,2_1",
-
-            //Failing 2019/07/01 - Libnd4j Concat sizing issue - https://github.com/eclipse/deeplearning4j/issues/7963
-            "boolean_mask/.*",
-
             //Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958
+            //Still failing 2019/09/11
             "slogdet/.*",

-            //Failing 2019/07/01 - https://github.com/eclipse/deeplearning4j/issues/7965
+            //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
             "bincount/.*",

             //TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
             "truncatemod/.*",

-            //Still failing as of 2019/07/02 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
+            //Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
             "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",

-            //2019/07/02 - No tensorflow op found for SparseTensorDenseAdd
+            //2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
             "confusion/.*",

-            //2019/07/02 - Couple of tests failing (InferenceSession issues)
+            //2019/09/11 - Couple of tests failing (InferenceSession issues)
             "rnn/bstack/d_.*",

             //2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere

@@ -103,13 +96,6 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             //2019/05/28 - JVM crash on ppc64le only - See issue 7657
             "g_11",

-            //2019/06/21 - Not yet implemented: https://github.com/eclipse/deeplearning4j/issues/7913
-            "fake_quant/min_max_args_per_channel/.*",
-
-            //2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935
-            "fake_quant/min_max_vars/.*",
-            "fake_quant/min_max_args/.*",
-
             //2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
             "multinomial/.*"
     };