parent
ba269a26ab
commit
364a6e1a2a
|
@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
|
public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) {
|
||||||
return new EluBp(sameDiff(), in, epsilon).outputVariable();
|
return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f(x) = alpha * (exp(x) - 1.0); x < 0
|
* f(x) = alpha * (exp(x) - 1.0); x < 0
|
||||||
|
@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
// no support in ELU native to override alpha
|
return Nd4j.exec(new ELU(in, in, alpha))[0];
|
||||||
if (this.alpha != 1.00) {
|
|
||||||
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
|
|
||||||
alphaMultiple.muli(alpha);
|
|
||||||
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
|
|
||||||
} else {
|
|
||||||
Nd4j.getExecutioner().execAndReturn(new ELU(in));
|
|
||||||
}
|
|
||||||
return in;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp {
|
||||||
|
|
||||||
public EluBp(){ }
|
public EluBp(){ }
|
||||||
|
|
||||||
public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
|
public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
|
||||||
super(sd, new SDVariable[]{input, gradient});
|
super(sd, new SDVariable[]{input, gradient});
|
||||||
|
addTArgument(alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
|
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
|
||||||
|
|
|
@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
|
||||||
import org.tensorflow.framework.GraphDef;
|
|
||||||
import org.tensorflow.framework.NodeDef;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ELU: Exponential Linear Unit (alpha=1.0)<br>
|
* ELU: Exponential Linear Unit (configurable alpha; default 1.0)<br>
|
||||||
|
@ -41,19 +37,31 @@ import java.util.Map;
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
public class ELU extends DynamicCustomOp {
|
public class ELU extends DynamicCustomOp {
|
||||||
|
public static final double DEFAULT_ALPHA = 1.0;
|
||||||
|
|
||||||
|
protected double alpha;
|
||||||
|
|
||||||
public ELU(SameDiff sameDiff, SDVariable i_v) {
|
public ELU(SameDiff sameDiff, SDVariable i_v) {
|
||||||
super(sameDiff, new SDVariable[]{i_v});
|
super(sameDiff, new SDVariable[]{i_v});
|
||||||
|
this.alpha = DEFAULT_ALPHA;
|
||||||
|
addTArgument(alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU() {
|
public ELU() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU(INDArray x, INDArray z) {
|
public ELU(INDArray x, INDArray z) {
|
||||||
|
this(x, z, DEFAULT_ALPHA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ELU(INDArray x, INDArray z, double alpha) {
|
||||||
super(null, wrapOrNull(x), wrapOrNull(z));
|
super(null, wrapOrNull(x), wrapOrNull(z));
|
||||||
|
this.alpha = alpha;
|
||||||
|
addTArgument(alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ELU(INDArray x) {
|
public ELU(INDArray x) {
|
||||||
this(x, null);
|
this(x, null, DEFAULT_ALPHA);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp {
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
//ELU: e^x-1 if x<0, x otherwise
|
//ELU: alpha*(e^x-1) if x<0, x otherwise
|
||||||
//dL/dIn = dL/Out * dOut/dIn
|
//dL/dIn = dL/Out * dOut/dIn
|
||||||
return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
|
return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue