diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 49e760961..ac017beef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory { } - public SDVariable eluBp(SDVariable in, SDVariable epsilon) { - return new EluBp(sameDiff(), in, epsilon).outputVariable(); + public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { + return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index b7ac3887c..b714b1f06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Pair; /** * f(x) = alpha * (exp(x) - 1.0); x < 0 @@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction { */ @Override public INDArray getActivation(INDArray in, boolean training) { - // no support in ELU native to override alpha - if (this.alpha != 1.00) { - INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0]; - alphaMultiple.muli(alpha); - BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); - } else { - Nd4j.getExecutioner().execAndReturn(new ELU(in)); - } - return in; + return Nd4j.exec(new ELU(in, in, alpha))[0]; } /* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java index f4624a6ee..0e2a4c6b9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java @@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp { public EluBp(){ } - public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){ + public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){ super(sd, new SDVariable[]{input, gradient}); + addTArgument(alpha); } public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index a144e868b..6923639fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.util.Collections; import java.util.List; -import java.util.Map; /** * ELU: Exponential Linear Unit (alpha=1.0)
@@ -41,19 +37,31 @@ import java.util.Map; * @author Alex Black */ public class ELU extends DynamicCustomOp { + public static final double DEFAULT_ALPHA = 1.0; + + protected double alpha; + public ELU(SameDiff sameDiff, SDVariable i_v) { super(sameDiff, new SDVariable[]{i_v}); + this.alpha = DEFAULT_ALPHA; + addTArgument(alpha); } public ELU() { } public ELU(INDArray x, INDArray z) { + this(x, z, DEFAULT_ALPHA); + } + + public ELU(INDArray x, INDArray z, double alpha) { super(null, wrapOrNull(x), wrapOrNull(z)); + this.alpha = alpha; + addTArgument(alpha); } public ELU(INDArray x) { - this(x, null); + this(x, null, DEFAULT_ALPHA); } @Override @@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); + return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); } @Override