Add java op class for relu derivative, and use in Activation ReLU (#207)
* Add java op class for relu derivative, and use in Activation ReLU

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

parent 70a9ae5068
commit a7dca9fc87
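For context: the commit adds a Java wrapper class, RectifiedLinearDerivative, for the native relu_bp op and uses it for the ReLU backward pass. Given the forward input and the incoming gradient (epsilon), the op passes the gradient through wherever the input is positive and outputs zero elsewhere. A minimal element-wise sketch of that rule, assuming the 0.0 cutoff that ActivationReLU uses (illustrative plain Java, not the actual native kernel):

    // Illustrative only: the element-wise rule relu_bp applies for a 0.0 cutoff.
    class ReluBackpropSketch {
        static double[] reluBackprop(double[] in, double[] epsilon) {
            double[] dLdz = new double[in.length];
            for (int i = 0; i < in.length; i++) {
                // the incoming gradient flows only where the forward input was positive
                dLdz[i] = in[i] > 0.0 ? epsilon[i] : 0.0;
            }
            return dLdz;
        }
    }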
@@ -133,23 +133,8 @@ import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
 import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
 import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
-import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
-import org.nd4j.linalg.api.ops.impl.scalar.LogX;
+import org.nd4j.linalg.api.ops.impl.scalar.*;
 import org.nd4j.linalg.api.ops.impl.scalar.Pow;
-import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
-import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
-import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
-import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction;
-import org.nd4j.linalg.api.ops.impl.scalar.Step;
 import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
 import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
 import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual;
@@ -211,6 +196,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Pow;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
@@ -1340,6 +1326,10 @@ public class DifferentialFunctionFactory {
         return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable();
     }
 
+    public SDVariable reluDerivative(SDVariable input, SDVariable grad){
+        return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
+    }
+
     public SDVariable relu6(SDVariable iX, double cutoff) {
         return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
     }
@@ -228,6 +228,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.scalar.Pow.class,
             org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
             org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
+            org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
             org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
             org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
             org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
@@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
 import org.nd4j.linalg.api.ops.impl.scalar.Step;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -41,8 +42,7 @@ public class ActivationReLU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
-        dLdz.muli(epsilon);
+        INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, epsilon, in.ulike()))[0];
         return new Pair<>(dLdz, null);
     }
 
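Note on the ActivationReLU change above: the previous backprop ran a Step op over the input and then multiplied the result by epsilon in place; the new version produces dL/dz in a single relu_bp call, writing into a fresh output array obtained from in.ulike().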
@@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseScalarOp;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 
 /**
@@ -77,8 +78,12 @@ public class RectifiedLinear extends BaseScalarOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
+        if(scalarValue.getDouble(0) == 0.0){
+            return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
+        } else {
         SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
         SDVariable ret = step.mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(ret);
+        }
     }
 }
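In the doDiff change above, the fused reluDerivative path is used only when the op's scalar cutoff is 0.0; for any other cutoff the original formulation (a Step op at the cutoff, multiplied by the incoming gradient) is kept, presumably because relu_bp assumes the standard zero threshold.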
@@ -0,0 +1,43 @@
+package org.nd4j.linalg.api.ops.impl.scalar;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.shade.guava.base.Preconditions;
+
+import java.util.Collections;
+import java.util.List;
+
+public class RectifiedLinearDerivative extends DynamicCustomOp {
+
+    public RectifiedLinearDerivative(){ }
+
+    public RectifiedLinearDerivative(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "relu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
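A hedged usage sketch for the new op, mirroring the ActivationReLU.backprop call from this commit; the wrapper class and the example values are assumptions for illustration:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReluBpExample {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(-1.0, 2.0, 3.0);   // forward-pass input
            INDArray eps = Nd4j.createFromArray(0.5, 0.5, 0.5);   // incoming gradient (dL/dout)
            // Nd4j.exec on a DynamicCustomOp returns its output arrays; in.ulike()
            // supplies an output buffer with the same shape and type as the input.
            INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, eps, in.ulike()))[0];
            System.out.println(dLdz);   // expected: [0, 0.5, 0.5] (gradient passes only where in > 0)
        }
    }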