From 65c9f2a888beb04f5ca888204ccbbdfdeff6e5fc Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 2 Sep 2019 17:42:12 +1000 Subject: [PATCH] ELU fix (#217) Signed-off-by: AlexDBlack --- .../DifferentialFunctionFactory.java | 2 +- .../activations/impl/ActivationELU.java | 2 +- .../api/ops/impl/transforms/strict/ELU.java | 40 ++++++++++++------- .../linalg/ops/transforms/Transforms.java | 2 +- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 0d58f024d..32d2d1474 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1559,7 +1559,7 @@ public class DifferentialFunctionFactory { public SDVariable elu(SDVariable iX) { - return new ELU(sameDiff(), iX, false).outputVariable(); + return new ELU(sameDiff(), iX).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index 665c84096..56fd84676 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -58,7 +58,7 @@ public class ActivationELU extends BaseActivationFunction { public INDArray getActivation(INDArray in, boolean training) { // no support in ELU native to override alpha if (this.alpha != 1.00) { - INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup())); + INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0]; alphaMultiple.muli(alpha); BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index 74d258fb1..7c85bfb1d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -16,17 +16,20 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; -import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; /** * ELU: Exponential Linear Unit (alpha=1.0)
@@ -37,25 +40,20 @@ import java.util.List; * * @author Alex Black */ -public class ELU extends BaseTransformStrictOp { - public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); +public class ELU extends DynamicCustomOp { + public ELU(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, new SDVariable[]{i_v}); } public ELU() { } public ELU(INDArray x, INDArray z) { - super(x, z); + super(null, wrapOrNull(x), wrapOrNull(z)); } public ELU(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 35; + this(x, null); } @Override @@ -73,6 +71,11 @@ public class ELU extends BaseTransformStrictOp { return "Elu"; } + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); + } + @Override public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise @@ -80,4 +83,11 @@ public class ELU extends BaseTransformStrictOp { return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); } + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes); + Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes); + + return dataTypes; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index 1f4004cb2..af95c73f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -438,7 +438,7 @@ public class Transforms { public static INDArray elu(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in))); + return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0]; } public static INDArray eluDerivative(INDArray arr) {