Add java op class for relu derivative, and use in Activation ReLU (#207)

* Add java op class for relu derivative, and use in Activation ReLU

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-30 23:36:00 +10:00 committed by GitHub
parent 70a9ae5068
commit a7dca9fc87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 21 deletions

View File

@ -133,23 +133,8 @@ import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.LogX;
import org.nd4j.linalg.api.ops.impl.scalar.Pow; import org.nd4j.linalg.api.ops.impl.scalar.Pow;
import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual;
@ -211,6 +196,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
@ -1340,6 +1326,10 @@ public class DifferentialFunctionFactory {
return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable();
} }
public SDVariable reluDerivative(SDVariable input, SDVariable grad){
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
}
public SDVariable relu6(SDVariable iX, double cutoff) { public SDVariable relu6(SDVariable iX, double cutoff) {
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
} }

View File

@ -228,6 +228,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.Pow.class, org.nd4j.linalg.api.ops.impl.scalar.Pow.class,
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class, org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class, org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class, org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.Step; import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
@ -41,8 +42,7 @@ public class ActivationReLU extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in)); INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, epsilon, in.ulike()))[0];
dLdz.muli(epsilon);
return new Pair<>(dLdz, null); return new Pair<>(dLdz, null);
} }

View File

@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseScalarOp;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
/** /**
@ -77,8 +78,12 @@ public class RectifiedLinear extends BaseScalarOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; if(scalarValue.getDouble(0) == 0.0){
SDVariable ret = step.mul(i_v.get(0)); return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
return Arrays.asList(ret); } else {
SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
SDVariable ret = step.mul(i_v.get(0));
return Collections.singletonList(ret);
}
} }
} }

View File

@ -0,0 +1,43 @@
package org.nd4j.linalg.api.ops.impl.scalar;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.shade.guava.base.Preconditions;
import java.util.Collections;
import java.util.List;
/**
 * Backprop (derivative) op for the rectified linear (ReLU) activation, mapping to the
 * native "relu_bp" custom op. Consumes the forward-pass input and the incoming gradient
 * (epsilon) and produces the gradient with respect to the input.
 */
public class RectifiedLinearDerivative extends DynamicCustomOp {

    /** No-arg constructor, required for reflective op instantiation/deserialization. */
    public RectifiedLinearDerivative() { }

    /**
     * SameDiff constructor.
     *
     * @param sd       SameDiff instance this op belongs to
     * @param input    Forward-pass input variable
     * @param gradient Gradient of the loss with respect to this op's output
     */
    public RectifiedLinearDerivative(SameDiff sd, SDVariable input, SDVariable gradient) {
        super(sd, new SDVariable[]{input, gradient});
    }

    /**
     * NDArray (eager) constructor.
     *
     * @param input    Forward-pass input array (must not be null)
     * @param gradient Gradient with respect to the op output (must not be null)
     * @param output   Optional pre-allocated output array; may be null
     */
    public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName() {
        return "relu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        // Exactly two floating-point inputs are expected: the forward input and the gradient
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
        // The single output has the same datatype as the forward-pass input
        final DataType outputType = dataTypes.get(0);
        return Collections.singletonList(outputType);
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        // Second-order gradients are not implemented for this backprop op
        throw new UnsupportedOperationException("Not supported");
    }
}