diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 49e760961..ac017beef 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory {
}
- public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
- return new EluBp(sameDiff(), in, epsilon).outputVariable();
+ public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) {
+ return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable();
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
index b7ac3887c..b714b1f06 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
@@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
-import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.BooleanIndexing;
-import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = alpha * (exp(x) - 1.0); x < 0
@@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction {
*/
@Override
public INDArray getActivation(INDArray in, boolean training) {
- // no support in ELU native to override alpha
- if (this.alpha != 1.00) {
- INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
- alphaMultiple.muli(alpha);
- BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
- } else {
- Nd4j.getExecutioner().execAndReturn(new ELU(in));
- }
- return in;
+ return Nd4j.exec(new ELU(in, in, alpha))[0];
}
/*
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
index f4624a6ee..0e2a4c6b9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
@@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp {
public EluBp(){ }
- public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
+ public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
super(sd, new SDVariable[]{input, gradient});
+ addTArgument(alpha);
}
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
index a144e868b..6923639fd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
@@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
/**
* ELU: Exponential Linear Unit (alpha=1.0)
@@ -41,19 +37,31 @@ import java.util.Map;
* @author Alex Black
*/
public class ELU extends DynamicCustomOp {
+ public static final double DEFAULT_ALPHA = 1.0;
+
+ protected double alpha;
+
public ELU(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, new SDVariable[]{i_v});
+ this.alpha = DEFAULT_ALPHA;
+ addTArgument(alpha);
}
public ELU() {
}
public ELU(INDArray x, INDArray z) {
+ this(x, z, DEFAULT_ALPHA);
+ }
+
+ public ELU(INDArray x, INDArray z, double alpha) {
super(null, wrapOrNull(x), wrapOrNull(z));
+ this.alpha = alpha;
+ addTArgument(alpha);
}
public ELU(INDArray x) {
- this(x, null);
+ this(x, null, DEFAULT_ALPHA);
}
@Override
@@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp {
public List doDiff(List i_v) {
//ELU: e^x-1 if x<0, x otherwise
//dL/dIn = dL/Out * dOut/dIn
- return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
+ return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha));
}
@Override