From 2d991f544512eecffeddbd8bebfe14cc95c4e8ef Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 10 Jul 2019 18:36:47 -0700 Subject: [PATCH] LeakyReLU fix (#55) * LeakyReLU: Use serScalar to set alpha correctly in TF import LogX: remove incorrect TF mapping Pow: remove TF import method (no mapping) BaseOp: remove duplicate extraArgs Signed-off-by: Ryan Nett * un-ignore cifar-10 gan, as it is now passing Signed-off-by: Ryan Nett --- .../main/java/org/nd4j/linalg/api/ops/BaseOp.java | 2 +- .../nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java | 13 +++++++++++++ .../org/nd4j/linalg/api/ops/impl/scalar/LogX.java | 2 +- .../org/nd4j/linalg/api/ops/impl/scalar/Pow.java | 13 ------------- .../nd4j/imports/TFGraphs/TFGraphTestZooModels.java | 3 --- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 22e2d8314..f499ea162 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -48,7 +48,7 @@ import java.util.Map; public abstract class BaseOp extends DifferentialFunction implements Op { protected INDArray x, y, z; - protected Object[] extraArgs; + @Getter @Setter protected String xVertexId,yVertexId,zVertexId; // cached instance, for dataType checks diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index 9f4117b2f..9f600b29b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.graph.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseTransformOp; @@ -26,6 +27,10 @@ import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; /** * Leaky Rectified linear unit. Default alpha=0.01, cutoff=0
@@ -106,4 +111,12 @@ public class LeakyReLU extends BaseScalarOp { SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0)); return Arrays.asList(ret); } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, + GraphDef graph) { + alpha = attributesForNode.get("alpha").getF(); + extraArgs = new Object[]{alpha}; + this.setScalar(Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, alpha)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LogX.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LogX.java index d6154a8f5..8758b0136 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LogX.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LogX.java @@ -81,6 +81,6 @@ public class LogX extends BaseScalarOp { @Override public String tensorflowName() { - return "LogX"; + throw new NoOpNameFoundException("No TensorFlow op found for " + getClass().getSimpleName()); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index aa3289d5c..8aafce3d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -78,19 +78,6 @@ public class Pow extends BaseScalarOp { return "pow"; } - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val weightsName = nodeDef.getInput(1); - val tmp = initWith.getArrForVarName(weightsName); - - // if second argument is scalar - we should provide array of same shape - if (tmp != null) { - if (tmp.isScalar()) { - this.pow = tmp.getDouble(0); - } - } - } - @Override public String onnxName() { return "Pow"; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index 4695893be..92de44e3b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -59,9 +59,6 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we //2019/07/10 - Libnd4j assign error - https://github.com/eclipse/deeplearning4j/issues/8002 "xlnet_cased_L-24_H-1024_A-16", - //2019/06/28 - Output incorrect, can't debug b/c https://github.com/eclipse/deeplearning4j/issues/7957 - "cifar10_gan_85", - //2019/07/03 - Out of Memory error "compression_residual_gru",