TF optional properties resolution fix (#48)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-06 14:33:14 +10:00 committed by AlexDBlack
parent 4b23dc01fd
commit 88ea9a49eb
2 changed files with 2 additions and 6 deletions

View File

@ -550,6 +550,8 @@ public abstract class DifferentialFunction {
continue; continue;
val var = sameDiff.getVarNameForFieldAndFunction(this,property); val var = sameDiff.getVarNameForFieldAndFunction(this,property);
if(var == null)
continue; //Rarely (like Conv2D) properties will be optional. For example kH/kW args will be inferred from weight shape
val fieldType = fields.get(property); val fieldType = fields.get(property);
val varArr = sameDiff.getArrForVarName(var); val varArr = sameDiff.getArrForVarName(var);
//already defined //already defined

View File

@ -30,12 +30,10 @@ import org.junit.runners.Parameterized;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.function.BiFunction; import org.nd4j.linalg.function.BiFunction;
import org.nd4j.resources.Downloader; import org.nd4j.resources.Downloader;
import org.nd4j.util.ArchiveUtils; import org.nd4j.util.ArchiveUtils;
@ -62,10 +60,6 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
//2019/06/28 - Output incorrect, can't debug b/c https://github.com/eclipse/deeplearning4j/issues/7957 //2019/06/28 - Output incorrect, can't debug b/c https://github.com/eclipse/deeplearning4j/issues/7957
"cifar10_gan_85", "cifar10_gan_85",
//2019/07/03 - java.lang.NullPointerException: varName is marked @NonNull but is null
// https://github.com/eclipse/deeplearning4j/issues/7975
"alexnet",
//2019/07/03 - Out of Memory error //2019/07/03 - Out of Memory error
"compression_residual_gru", "compression_residual_gru",