Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-02 17:42:12 +10:00 committed by GitHub
parent acf559425a
commit 65c9f2a888
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 18 deletions

View File

@ -1559,7 +1559,7 @@ public class DifferentialFunctionFactory {
public SDVariable elu(SDVariable iX) { public SDVariable elu(SDVariable iX) {
return new ELU(sameDiff(), iX, false).outputVariable(); return new ELU(sameDiff(), iX).outputVariable();
} }

View File

@ -58,7 +58,7 @@ public class ActivationELU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
// no support in ELU native to override alpha // no support in ELU native to override alpha
if (this.alpha != 1.00) { if (this.alpha != 1.00) {
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup())); INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
alphaMultiple.muli(alpha); alphaMultiple.muli(alpha);
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
} else { } else {

View File

@ -16,17 +16,20 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.BaseTransformOp; import org.tensorflow.framework.AttrValue;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* ELU: Exponential Linear Unit (alpha=1.0)<br> * ELU: Exponential Linear Unit (alpha=1.0)<br>
@ -37,25 +40,20 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
public class ELU extends BaseTransformStrictOp { public class ELU extends DynamicCustomOp {
public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public ELU(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, i_v, inPlace); super(sameDiff, new SDVariable[]{i_v});
} }
public ELU() { public ELU() {
} }
public ELU(INDArray x, INDArray z) { public ELU(INDArray x, INDArray z) {
super(x, z); super(null, wrapOrNull(x), wrapOrNull(z));
} }
public ELU(INDArray x) { public ELU(INDArray x) {
super(x); this(x, null);
}
@Override
public int opNum() {
return 35;
} }
@Override @Override
@ -73,6 +71,11 @@ public class ELU extends BaseTransformStrictOp {
return "Elu"; return "Elu";
} }
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
//ELU: e^x-1 if x<0, x otherwise //ELU: e^x-1 if x<0, x otherwise
@ -80,4 +83,11 @@ public class ELU extends BaseTransformStrictOp {
return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
} }
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes);
Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes);
return dataTypes;
}
} }

View File

@ -438,7 +438,7 @@ public class Transforms {
public static INDArray elu(INDArray in, boolean copy) { public static INDArray elu(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in))); return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
} }
public static INDArray eluDerivative(INDArray arr) { public static INDArray eluDerivative(INDArray arr) {