diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 035ff0960..3343d9dde 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -133,23 +133,8 @@ import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; -import org.nd4j.linalg.api.ops.impl.scalar.LogX; +import org.nd4j.linalg.api.ops.impl.scalar.*; import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; -import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; -import org.nd4j.linalg.api.ops.impl.scalar.Relu6; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction; -import org.nd4j.linalg.api.ops.impl.scalar.Step; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; @@ -211,6 +196,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.custom.*; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Pow; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; @@ -1340,6 +1326,10 @@ public class DifferentialFunctionFactory { return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); } + public SDVariable reluDerivative(SDVariable input, SDVariable grad){ + return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); + } + public SDVariable relu6(SDVariable iX, double cutoff) { return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index e2d3cf87b..da580b748 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -228,6 +228,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scalar.Pow.class, org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, + org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.Relu6.class, org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class, org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java index 36313593d..cf0ec7fa0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative; import org.nd4j.linalg.api.ops.impl.scalar.Step; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; @@ -41,8 +42,7 @@ public class ActivationReLU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in)); - dLdz.muli(epsilon); + INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, epsilon, in.ulike()))[0]; return new Pair<>(dLdz, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 15673a5a0..98fa587b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -77,8 +78,12 @@ public class RectifiedLinear extends BaseScalarOp { @Override public List doDiff(List i_v) { - SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; - SDVariable ret = step.mul(i_v.get(0)); - return Arrays.asList(ret); + if(scalarValue.getDouble(0) == 0.0){ + return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0))); + } else { + SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; + SDVariable ret = step.mul(i_v.get(0)); + return Collections.singletonList(ret); + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java new file mode 100644 index 000000000..3af7a4190 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java @@ -0,0 +1,43 @@ +package org.nd4j.linalg.api.ops.impl.scalar; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.shade.guava.base.Preconditions; + +import java.util.Collections; +import java.util.List; + +public class RectifiedLinearDerivative extends DynamicCustomOp { + + public RectifiedLinearDerivative(){ } + + public RectifiedLinearDerivative(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "relu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +}