parent
acf559425a
commit
65c9f2a888
|
@ -1559,7 +1559,7 @@ public class DifferentialFunctionFactory {
|
||||||
|
|
||||||
|
|
||||||
public SDVariable elu(SDVariable iX) {
|
public SDVariable elu(SDVariable iX) {
|
||||||
return new ELU(sameDiff(), iX, false).outputVariable();
|
return new ELU(sameDiff(), iX).outputVariable();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class ActivationELU extends BaseActivationFunction {
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
// no support in ELU native to override alpha
|
// no support in ELU native to override alpha
|
||||||
if (this.alpha != 1.00) {
|
if (this.alpha != 1.00) {
|
||||||
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()));
|
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
|
||||||
alphaMultiple.muli(alpha);
|
alphaMultiple.muli(alpha);
|
||||||
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
|
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -16,17 +16,20 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ELU: Exponential Linear Unit (alpha=1.0)<br>
|
* ELU: Exponential Linear Unit (alpha=1.0)<br>
|
||||||
|
@ -37,25 +40,20 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
public class ELU extends BaseTransformStrictOp {
|
public class ELU extends DynamicCustomOp {
|
||||||
public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
|
public ELU(SameDiff sameDiff, SDVariable i_v) {
|
||||||
super(sameDiff, i_v, inPlace);
|
super(sameDiff, new SDVariable[]{i_v});
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU() {
|
public ELU() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU(INDArray x, INDArray z) {
|
public ELU(INDArray x, INDArray z) {
|
||||||
super(x, z);
|
super(null, wrapOrNull(x), wrapOrNull(z));
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU(INDArray x) {
|
public ELU(INDArray x) {
|
||||||
super(x);
|
this(x, null);
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 35;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -73,6 +71,11 @@ public class ELU extends BaseTransformStrictOp {
|
||||||
return "Elu";
|
return "Elu";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
//ELU: e^x-1 if x<0, x otherwise
|
//ELU: e^x-1 if x<0, x otherwise
|
||||||
|
@ -80,4 +83,11 @@ public class ELU extends BaseTransformStrictOp {
|
||||||
return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
|
return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes);
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes);
|
||||||
|
|
||||||
|
return dataTypes;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -438,7 +438,7 @@ public class Transforms {
|
||||||
|
|
||||||
|
|
||||||
public static INDArray elu(INDArray in, boolean copy) {
|
public static INDArray elu(INDArray in, boolean copy) {
|
||||||
return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)));
|
return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray eluDerivative(INDArray arr) {
|
public static INDArray eluDerivative(INDArray arr) {
|
||||||
|
|
Loading…
Reference in New Issue