From b3a134b608337a7eae11cf3ded5737e8e6508813 Mon Sep 17 00:00:00 2001
From: Ryan Nett
Date: Sun, 1 Sep 2019 23:15:23 -0700
Subject: [PATCH] New Nd4j backprop ops for activations (#211)

* new (for java at least) backprop ops

Signed-off-by: Ryan Nett

* update activation functions

Signed-off-by: Ryan Nett

* add differential functions for SameDiff

Signed-off-by: Ryan Nett

* deprecate old ops

Signed-off-by: Ryan Nett

* update correct old ops

Signed-off-by: Ryan Nett

* update ops backprop to use new ops

Signed-off-by: Ryan Nett

* misc updates for deprecated functions (mostly Nd4j.rand w/ vararg shape)

Signed-off-by: Ryan Nett

* remove old imports

Signed-off-by: Ryan Nett
---
 .../DifferentialFunctionFactory.java          | 85 ++++++++++++++++++-
 .../autodiff/validation/OpValidation.java     | 11 +++
 .../converters/ImportClassMapping.java        | 12 +++
 .../activations/impl/ActivationCube.java      |  8 +-
 .../activations/impl/ActivationELU.java       |  6 +-
 .../impl/ActivationHardSigmoid.java           |  8 +-
 .../activations/impl/ActivationHardTanH.java  |  9 +-
 .../activations/impl/ActivationLReLU.java     |  9 +-
 .../activations/impl/ActivationRReLU.java     |  2 +-
 .../impl/ActivationRationalTanh.java          |  9 +-
 .../activations/impl/ActivationReLU6.java     |  9 +-
 .../impl/ActivationRectifiedTanh.java         |  9 +-
 .../activations/impl/ActivationSELU.java      |  9 +-
 .../activations/impl/ActivationSigmoid.java   | 10 ++-
 .../activations/impl/ActivationSoftPlus.java  |  9 +-
 .../activations/impl/ActivationSoftSign.java  |  9 +-
 .../activations/impl/ActivationSoftmax.java   | 10 ++-
 .../activations/impl/ActivationTanH.java      |  9 +-
 .../linalg/api/ops/impl/scalar/LeakyReLU.java | 17 ++--
 .../api/ops/impl/scalar/RectifiedLinear.java  |  9 +-
 .../scalar/RectifiedLinearDerivative.java     |  5 +-
 .../impl/transforms/custom/ThresholdRelu.java | 77 +++++++++++++++++
 .../ops/impl/transforms/gradient/CubeBp.java  | 62 ++++++++++++++
 .../transforms/gradient/CubeDerivative.java   |  4 +
 .../transforms/gradient/ELUDerivative.java    |  3 +
 .../ops/impl/transforms/gradient/EluBp.java   | 62 ++++++++++++++
 .../transforms/gradient/HardSigmoidBp.java    | 62 ++++++++++++++
 .../gradient/HardSigmoidDerivative.java       |  3 +
 .../impl/transforms/gradient/HardTanhBp.java  | 62 ++++++++++++++
 .../gradient/HardTanhDerivative.java          |  3 +
 .../impl/transforms/gradient/LeakyReLUBp.java | 68 +++++++++++++++
 .../transforms/gradient/RationalTanhBp.java   | 62 ++++++++++++++
 .../gradient/RationalTanhDerivative.java      |  3 +
 .../transforms/gradient/RectifiedTanhBp.java  | 62 ++++++++++++++
 .../gradient/RectifiedTanhDerivative.java     |  3 +
 .../transforms/gradient/Relu6Derivative.java  | 17 +++-
 .../transforms/gradient/SELUDerivative.java   |  3 +
 .../ops/impl/transforms/gradient/SeluBp.java  | 62 ++++++++++++++
 .../impl/transforms/gradient/SoftPlusBp.java  | 62 ++++++++++++++
 .../impl/transforms/gradient/SoftSignBp.java  | 63 ++++++++++++++
 .../gradient/SoftSignDerivative.java          |  3 +
 .../impl/transforms/gradient/SoftmaxBp.java   |  8 ++
 .../transforms/gradient/TanhDerivative.java   |  8 +-
 .../transforms/gradient/ThresholdReluBp.java  | 74 ++++++++++++++++
 .../api/ops/impl/transforms/same/Cube.java    |  4 +-
 .../api/ops/impl/transforms/strict/ELU.java   |  4 +-
 .../impl/transforms/strict/HardSigmoid.java   |  4 +-
 .../ops/impl/transforms/strict/HardTanh.java  |  4 +-
 .../impl/transforms/strict/RationalTanh.java  |  5 +-
 .../impl/transforms/strict/RectifiedTanh.java |  5 +-
 .../api/ops/impl/transforms/strict/SELU.java  |  4 +-
 .../transforms/strict/SigmoidDerivative.java  |  3 +
 .../ops/impl/transforms/strict/SoftSign.java  |  4 +-
 .../transforms/strict/TanhDerivative.java     |  3 +
 .../impl/LecunUniformInitScheme.java          |  2 +-
 .../impl/ReluUniformInitScheme.java           |  2 +-
 .../impl/SigmoidUniformInitScheme.java        |  2 +-
 .../weightinit/impl/UniformInitScheme.java    |  2 +-
 ...arScalingNormalUniformFanInInitScheme.java |  2 +-
 ...rScalingNormalUniformFanOutInitScheme.java |  2 +-
 .../VarScalingUniformFanAvgInitScheme.java    |  2 +-
 .../impl/XavierUniformInitScheme.java         |  2 +-
 62 files changed, 1053 insertions(+), 103 deletions(-)
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java
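The change repeated across the files below swaps the old two-step backprop pattern (run a *Derivative transform to get dOut/dIn, then multiply by the incoming gradient) for a single fused *Bp custom op that takes the forward input and dL/dOut together and writes dL/dIn in one pass. A minimal before/after sketch using the CubeBp/CubeDerivative pair from this patch; the scaffolding class and test values are illustrative only, not part of the change, and assume an Nd4j backend on the classpath:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j;

public class FusedBackpropSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.linspace(-2, 2, 9);   // forward-pass input
        INDArray eps = Nd4j.ones(9);             // incoming gradient dL/dOut

        // Old (now deprecated) pattern: derivative op, then element-wise multiply.
        // dup() because the legacy transform op may operate in place on its input.
        INDArray dLdzOld = Nd4j.getExecutioner().exec(new CubeDerivative(in.dup()));
        dLdzOld.muli(eps);

        // New pattern: one fused op computes dL/dIn = 3 * in^2 * eps directly.
        INDArray dLdzNew = Nd4j.createUninitialized(in.dataType(), in.shape());
        Nd4j.getExecutioner().exec(new CubeBp(in, eps, dLdzNew));

        System.out.println(dLdzOld.equalsWithEps(dLdzNew, 1e-5)); // expect true
    }
}

Fusing the multiply into the native op saves one full-size temporary and one extra pass over the buffer, which is what motivates both the activation-class rewrites and the doDiff rewiring below.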
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 7ffeca762..0d58f024d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -204,20 +204,30 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
 import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
@@ -1126,10 +1136,26 @@ public class DifferentialFunctionFactory {
         return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
     }
 
+    public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) {
+        return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) {
+        return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #tanhRationalBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable tanhRationalDerivative(SDVariable in) {
         return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
     }
 
+    /**
+     * @deprecated Use {@link #tanhRectifiedBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable tanhRectifiedDerivative(SDVariable in) {
         return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
     }
@@ -1280,6 +1306,14 @@ public class DifferentialFunctionFactory {
         return new Cube(sameDiff(), iX, false).outputVariable();
     }
 
+    public SDVariable cubeBp(SDVariable in, SDVariable epsilon) {
+        return new CubeBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #cubeBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable cubeDerivative(SDVariable iX) {
         return new CubeDerivative(sameDiff(), iX, false).outputVariable();
     }
@@ -1329,6 +1363,14 @@ public class DifferentialFunctionFactory {
         return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
     }
 
+    public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){
+        return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable();
+    }
+
+    public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){
+        return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable();
+    }
+
     public SDVariable relu6(SDVariable iX, double cutoff) {
         return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
     }
@@ -1350,6 +1392,14 @@ public class DifferentialFunctionFactory {
         return new HardTanh(sameDiff(), iX, false).outputVariable();
     }
 
+    public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) {
+        return new HardTanhBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable hardTanhDerivative(SDVariable iX) {
         return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
     }
@@ -1358,6 +1408,9 @@ public class DifferentialFunctionFactory {
         return new HardSigmoid(sameDiff(), in, false).outputVariable();
     }
 
+    public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){
+        return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable();
+    }
 
     public SDVariable sigmoid(SDVariable iX) {
         return new Sigmoid(sameDiff(), iX, false).outputVariable();
@@ -1486,10 +1539,16 @@ public class DifferentialFunctionFactory {
 
     }
 
+    public SDVariable softsignBp(SDVariable in, SDVariable epsilon) {
+        return new SoftSignBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #softsignBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable softsignDerivative(SDVariable iX) {
         return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
-
     }
@@ -1504,10 +1563,16 @@ public class DifferentialFunctionFactory {
 
     }
 
+    public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
+        return new EluBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable eluDerivative(SDVariable iX) {
         return new ELUDerivative(sameDiff(), iX, false).outputVariable();
-
     }
@@ -1516,6 +1581,14 @@ public class DifferentialFunctionFactory {
 
     }
 
+    public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) {
+        return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)}
+     */
+    @Deprecated
     public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
         return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
     }
@@ -1832,7 +1905,15 @@ public class DifferentialFunctionFactory {
         return new SELU(sameDiff(), arg, false).outputVariable();
     }
 
+    public SDVariable seluBp(SDVariable in, SDVariable epsilon) {
+        validateDifferentialFunctionsameDiff(in);
+        return new SeluBp(sameDiff(), in, epsilon).outputVariable();
+    }
+
+    /**
+     * @deprecated Use {@link #seluBp(SDVariable, SDVariable)}
+     */
+    @Deprecated
     public SDVariable seluDerivative(SDVariable arg) {
         validateDifferentialFunctionsameDiff(arg);
         return new SELUDerivative(sameDiff(), arg, false).outputVariable();
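On the SameDiff side, these factory methods are what the ops' doDiff implementations now call; Cube's gradient, for instance, goes through cubeBp rather than cubeDerivative(...).mul(epsilon) (the Cube.java hunk itself is outside this excerpt, though the diffstat shows the small change). A hedged sanity-check sketch, using public SameDiff API names of the 1.0.0-beta line and illustrative variable names:

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CubeGradientSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.linspace(-2, 2, 8).reshape(2, 4));
        SDVariable loss = sd.math().cube(in).sum();  // backprop now routes through CubeBp

        sd.execBackwards(Collections.<String, INDArray>emptyMap());

        INDArray dLdIn = in.gradient().getArr();     // expect 3 * in^2
        System.out.println(dLdIn);
    }
}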
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
index 5bc175952..42485331d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
@@ -901,6 +901,17 @@ public class OpValidation {
                 TanhDerivative.class,
                 org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
                 PowDerivative.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
                 BiasAddGrad.class,
                 ConcatBp.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index da580b748..95b800e6a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -229,6 +229,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
             org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
             org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
             org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
             org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
             org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
@@ -433,6 +434,17 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java
index 79fbcea5a..766f4a4db 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java
@@ -21,6 +21,8 @@ import lombok.Getter;
 import lombok.NonNull;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
 import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
 import org.nd4j.linalg.factory.Nd4j;
@@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
index 48e118d06..665c84096 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
@@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -85,9 +86,8 @@ public class ActivationELU extends BaseActivationFunction {
         }
         else {
-            INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in));
-            dLdz.muli(epsilon);
-            return new Pair<>(dLdz, null);
+            Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
+            return new Pair<>(in, null);
         }
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java
index b2b73be0e..4076b40e3 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
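One caller-visible consequence of these rewrites is worth flagging: the new code passes the activation input as the op's output buffer, so backprop overwrites in with dL/dIn and returns that same array, whereas the old code returned a freshly allocated dLdz. A small sketch of the effect (values illustrative; dup() the input first if the pre-activations are still needed afterwards):

import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

public class InPlaceBackpropSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.linspace(-2, 2, 6);   // pre-activations (will be overwritten)
        INDArray eps = Nd4j.ones(6);             // dL/dOut from the layer above

        Pair<INDArray, INDArray> grads = new ActivationCube().backprop(in, eps);

        // The first element IS the mutated input buffer, not a new array.
        System.out.println(grads.getFirst() == in); // true
    }
}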
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java
index 7cf80ffc3..f8c405d38 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java
index f59a7ddb0..864f16901 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java
index f2e01508f..8d5bb3ddd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java
@@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction {
     public INDArray getActivation(INDArray in, boolean training) {
         if (training) {
             try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
-                this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
+                this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
             }
             INDArray inTimesAlpha = in.mul(alpha);
             BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java
index 84c8878bc..0e6cc2a51 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java
index f7bc24966..611f2c9ee 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java
@@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode;
 import lombok.Getter;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
 import org.nd4j.linalg.api.ops.impl.scalar.Step;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
@@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java
index 58ff1bc7b..ccd4cafe2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -45,9 +47,10 @@ public class ActivationRectifiedTanh extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java
index 773a2578a..3eed5ac9c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java
index aa7e7a1c6..72500e677 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java
@@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java
index fa5fe3ef8..0eb7781f2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java
index ff3b298e4..3857ff084 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java
index 2fd8b439a..095c2548a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java
@@ -21,7 +21,9 @@ import lombok.Getter;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
@@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
-        INDArray x = out.mul(epsilon).sum(1);
-        INDArray dLdz = out.mul(epsilon.subColumnVector(x));
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java
index 038d6032d..a30b8e303 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java
@@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
 
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
 import org.nd4j.linalg.factory.Nd4j;
 
 /**
@@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+
+        Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
+
+        return new Pair<>(in, null);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java
index 9f600b29b..fe70de288 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java
@@ -16,17 +16,13 @@
 
 package org.nd4j.linalg.api.ops.impl.scalar;
 
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.graph.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseScalarOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.LinkedHashMap;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.BaseScalarOp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha));
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java
index 98fa587b5..ca8cee2f1 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java
@@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp {
         return "Relu";
     }
 
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        if(scalarValue.getDouble(0) == 0.0){
-            return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
-        } else {
-            SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
-            SDVariable ret = step.mul(i_v.get(0));
-            return Collections.singletonList(ret);
-        }
+        return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java
index 3af7a4190..7e4d0fa09 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java
@@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.shade.guava.base.Preconditions;
 
 import java.util.Collections;
 import java.util.List;
@@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp {
 
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
-        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
         Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
 
         return Collections.singletonList(dataTypes.get(0));
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java
new file mode 100644
index 000000000..82e2ae6e3
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
+
+/**
+ * Threshold ReLU op. The general case of {@link RectifiedLinear}.
+ */
+public class ThresholdRelu extends DynamicCustomOp {
+
+    @Getter
+    private double cutoff = 0.0;
+
+    public ThresholdRelu(){ }
+
+    public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){
+        super(sd, new SDVariable[]{input}, inPlace);
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){
+        super(sd, new SDVariable[]{input});
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){
+        super(new INDArray[]{input}, wrapOrNull(output));
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    @Override
+    public String opName(){
+        return "thresholdedrelu";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff));
+    }
+}
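For reference, the thresholded ReLU exposed by this wrapper computes f(x) = x for x > cutoff and 0 otherwise, so cutoff = 0 reduces to plain ReLU, and its gradient (ThresholdReluBp below) is correspondingly 1 above the cutoff and 0 elsewhere. A minimal forward sketch using the INDArray constructor added here; values are illustrative, and the strict-inequality boundary follows the ONNX ThresholdedRelu definition that the op name points at:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;
import org.nd4j.linalg.factory.Nd4j;

public class ThresholdReluSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[]{-1.0, 0.5, 1.0, 2.0});
        INDArray out = x.ulike();                      // same shape/dtype as x

        Nd4j.getExecutioner().exec(new ThresholdRelu(x, out, 1.0));

        System.out.println(out); // expect [0, 0, 0, 2.0] with cutoff = 1.0
    }
}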
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java
new file mode 100644
index 000000000..16f67c910
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Cube backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class CubeBp extends DynamicCustomOp {
+
+    public CubeBp(){ }
+
+    public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "cube_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java
index 9c6a000c1..af9985e89 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java
@@ -27,7 +27,11 @@ import java.util.List;
 
 /**
  * Cube derivative, e.g. 3x^2
+ *
+ * @deprecated Use {@link CubeBp}
+ *
  */
+@Deprecated
 public class CubeDerivative extends BaseTransformStrictOp {
 
     public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java
index 45357fb0b..016890f58 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java
@@ -35,8 +35,11 @@ import java.util.List;
  * Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)
  * http://arxiv.org/abs/1511.07289
  *
+ * @deprecated Use {@link EluBp}
+ *
  * @author Alex Black
  */
+@Deprecated
 public class ELUDerivative extends BaseTransformStrictOp {
 
     public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
new file mode 100644
index 000000000..e886716c1
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * ELU backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class EluBp extends DynamicCustomOp {
+
+    public EluBp(){ }
+
+    public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "elu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java
new file mode 100644
index 000000000..7bf905c5d
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class HardSigmoidBp extends DynamicCustomOp {
+
+    public HardSigmoidBp(){ }
+
+    public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "hardsigmoid_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java
index 01420d98a..c7328a92b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java
@@ -29,8 +29,11 @@ import java.util.List;
 
 /**
  * HardSigmoid derivative
+ *
+ * @deprecated Use {@link HardSigmoidBp}
+ *
  * @author raver119@gmail.com
 */
+@Deprecated
 public class HardSigmoidDerivative extends BaseTransformStrictOp {
 
     public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java
new file mode 100644
index 000000000..10102eb2b
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Hard Tanh backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class HardTanhBp extends DynamicCustomOp {
+
+    public HardTanhBp(){ }
+
+    public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "hardtanh_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java
index eb0a28e09..e5322c02f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java
@@ -31,8 +31,11 @@ import java.util.List;
 
 /**
  * Hard tanh elementwise derivative function
+ *
+ * @deprecated Use {@link HardTanhBp}
+ *
  * @author Adam Gibson
 */
+@Deprecated
 public class HardTanhDerivative extends BaseTransformStrictOp {
 
     public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java
new file mode 100644
index 000000000..60ef40423
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * LReLU backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class LeakyReLUBp extends DynamicCustomOp {
+    public static final double DEFAULT_ALPHA = 0.01;
+    private double alpha = DEFAULT_ALPHA;
+
+    public LeakyReLUBp(){ }
+
+    public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
+        super(sd, new SDVariable[]{input, gradient});
+        this.alpha = alpha;
+        addTArgument(alpha);
+    }
+
+    public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        this.alpha = alpha;
+        addTArgument(alpha);
+    }
+
+    @Override
+    public String opName(){
+        return "lrelu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
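As a point of reference before the remaining ops, here is a minimal usage sketch of the INDArray constructor pattern these new classes share. It is not part of the patch: the wrapper class name and values are illustrative, and it assumes the preallocated-output pattern with Nd4j.exec(CustomOp).

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
    import org.nd4j.linalg.factory.Nd4j;

    public class LeakyReluBpSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.randn(3, 4);        // activations, mixed sign
            INDArray dLdOut = Nd4j.ones(3, 4);    // gradient flowing back from the next layer
            INDArray dLdIn = Nd4j.zerosLike(x);   // preallocated output buffer

            // Fused backprop: dL/dIn = dL/dOut where x > 0, alpha * dL/dOut elsewhere
            Nd4j.exec(new LeakyReLUBp(x, dLdOut, dLdIn, LeakyReLUBp.DEFAULT_ALPHA));
            System.out.println(dLdIn);
        }
    }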
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java
new file mode 100644
index 000000000..f70d79ae1
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Rational Tanh backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class RationalTanhBp extends DynamicCustomOp {
+
+    public RationalTanhBp(){ }
+
+    public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "rationaltanh_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java
index 4f0f5915e..18443e9bb 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java
@@ -31,9 +31,12 @@ import java.util.List;
  * Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351
  * Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input
  *
+ * @deprecated Use {@link RationalTanhBp}
+ *
  * @author raver119@gmail.com
  * @author AlexDBlack
  */
+@Deprecated
 public class RationalTanhDerivative extends BaseTransformStrictOp {
     public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
         super(sameDiff, in, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java
new file mode 100644
index 000000000..c10d6071a
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class RectifiedTanhBp extends DynamicCustomOp {
+
+    public RectifiedTanhBp(){ }
+
+    public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "rectifiedtanh_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java
index 37acf94ac..8c896fb10 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java
@@ -30,9 +30,12 @@ import java.util.List;
 /**
  * Rectified Tanh Derivative
  *
+ * @deprecated Use {@link RectifiedTanhBp}
+ *
  * @author raver119@gmail.com
  * @author AlexDBlack
  */
+@Deprecated
 public class RectifiedTanhDerivative extends BaseTransformStrictOp {
     public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
         super(sameDiff, in, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java
index c915658b8..3477b4e71 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java
@@ -16,15 +16,18 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Collections;
 import java.util.List;
+import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
 
 /**
  * Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen.
@@ -33,7 +36,9 @@ import java.util.List;
  */
 public class Relu6Derivative extends DynamicCustomOp {
 
-    private double cutoff = 0.0;
+    private static final double DEFAULT_CUTOFF = 0.0;
+
+    private double cutoff = DEFAULT_CUTOFF;
 
     public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) {
         super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2});
@@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp {
         this.extraArgs = new Object[]{cutoff};
     }
 
+    public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        this(input, gradient, output, DEFAULT_CUTOFF);
+    }
+
+    public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        this.cutoff = cutoff;
+        this.extraArgs = new Object[]{cutoff};
+    }
+
     @Override
     public int opNum() {
         return 0;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java
index 58877f041..b00b29b75 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java
@@ -31,8 +31,11 @@ import java.util.List;
  *
  * https://arxiv.org/pdf/1706.02515.pdf
  *
+ * @deprecated Use {@link SeluBp}
+ *
  * @author raver119@gmail.com
  */
+@Deprecated
 public class SELUDerivative extends BaseTransformStrictOp {
 
     private static final double SELU_ALPHA = 1.6732632423543772848170429916717;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java
new file mode 100644
index 000000000..a13171e10
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * SELU backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class SeluBp extends DynamicCustomOp {
+
+    public SeluBp(){ }
+
+    public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "selu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java
new file mode 100644
index 000000000..be8c1b702
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * SoftPlus backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class SoftPlusBp extends DynamicCustomOp {
+
+    public SoftPlusBp(){ }
+
+    public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "softplus_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java
new file mode 100644
index 000000000..c636361e6
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * SoftSign backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class SoftSignBp extends DynamicCustomOp {
+
+    public SoftSignBp(){ }
+
+    public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "softsign_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java
index 3741cfe90..4ae26e585 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java
@@ -29,7 +29,10 @@ import java.util.List;
 /**
  * SoftSign derivative.
+ *
+ * @deprecated Use {@link SoftSignBp}
  */
+@Deprecated
 public class SoftSignDerivative extends BaseTransformStrictOp {
     public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java
index 1a018d2e0..dbbdb8dde 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java
@@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Collections;
@@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp {
         addIArgument(dimension);
     }
 
+    public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
+        super(new INDArray[]{input, grad}, wrapOrNull(output));
+        if(dimension != null)
+            addIArgument(dimension);
+    }
+
     @Override
     public String opName() {
         return "softmax_bp";
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java
index be0f7d85f..4d8209b8a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java
@@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{i_v1, i_v2});
     }
 
-    public TanhDerivative(INDArray x, INDArray z) {
-        super(null, x, z, null, null);
+    public TanhDerivative(INDArray x, INDArray y, INDArray z) {
+        super(null, new INDArray[]{x, y}, new INDArray[]{z});
     }
 
     public TanhDerivative() {
     }
 
-    public TanhDerivative(INDArray x) {
-        this(x, null);
+    public TanhDerivative(INDArray x, INDArray y) {
+        this(x, y, null);
     }
 
     @Override
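The TanhDerivative signature change above means the op now consumes the incoming gradient as a second input instead of producing only the local derivative. A hypothetical call, not part of the patch (the wrapper class name and values are illustrative, and the preallocated-output pattern with Nd4j.exec(CustomOp) is assumed):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
    import org.nd4j.linalg.factory.Nd4j;

    public class TanhBpSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.randn(2, 3);
            INDArray dLdOut = Nd4j.ones(2, 3);
            INDArray dLdIn = Nd4j.zerosLike(x);

            // The old (x, z) form computed only dOut/dIn = 1 - tanh^2(x); the new
            // (x, y, z) form also takes dL/dOut and writes dL/dIn = (1 - tanh^2(x)) * dL/dOut.
            Nd4j.exec(new TanhDerivative(x, dLdOut, dLdIn));
            System.out.println(dLdIn);
        }
    }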
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java
new file mode 100644
index 000000000..8d04a7118
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;
+
+/**
+ * Threshold ReLU Backprop op - dL/dIn from in and dL/dOut
+ *
+ * For {@link RectifiedLinear} as well as {@link ThresholdRelu}.
+ */
+public class ThresholdReluBp extends DynamicCustomOp {
+
+    @Getter
+    private double cutoff = 0;
+
+    public ThresholdReluBp(){ }
+
+    public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){
+        super(sd, new SDVariable[]{input, gradient});
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    @Override
+    public String opName(){
+        return "thresholdedrelu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java
index 9e1e05693..d58ad8f3f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same;
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable g = f().mul(f().cubeDerivative(arg()),f1.get(0));
-        return Arrays.asList(g);
+        return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
index 8266add39..74d258fb1 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -76,8 +77,7 @@ public class ELU extends BaseTransformStrictOp {
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         //ELU: e^x-1 if x<0, x otherwise
         //dL/dIn = dL/Out * dOut/dIn
-        SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java
index a1703d221..ddca48d4c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java
@@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable in = arg();
-        SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0];
-        return Collections.singletonList(dOutdIn.mul(f1.get(0)));
+        return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0)));
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java
index a2452443b..4237e72de 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java
index 2de0e90a7..a05e34637 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java
@@ -16,13 +16,10 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
-import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
 import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
 
 import java.util.Collections;
@@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0)));
+        return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java
index da439cec7..d5fbf1294 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java
@@ -17,13 +17,10 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
 import onnx.Onnx;
-import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
 import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0)));
+        return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java
index 159b0b170..f72676f86 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().seluBp(arg(), i_v.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java
index 7213b50b0..08a97bae7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java
@@ -28,8 +28,11 @@ import java.util.List;
 /**
  * Sigmoid derivative
  *
+ * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative}
+ *
  * @author Adam Gibson
  */
+@Deprecated
 public class SigmoidDerivative extends BaseTransformStrictOp {
     public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
         super(sameDiff, i_v1, i_v2);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java
index 0b2782860..c7c90b201 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java
@@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().softsignBp(arg(), i_v.get(0)));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java
index fc9a9581f..fad63e73b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java
@@ -27,7 +27,10 @@ import java.util.List;
 /**
  * Tanh derivative
+ *
+ * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}.
  */
+@Deprecated
 public class TanhDerivative extends BaseTransformStrictOp {
     public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java
index 67ff1a114..73b471535 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java
@@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double b = 3.0 / Math.sqrt(fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java
index 9561953e0..07eeadeb7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java
@@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double u = Math.sqrt(6.0 / fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn))
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java
index 58809a095..4c7420b38 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java
@@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape);
     }
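The three weight-init hunks above and the five below all make the same mechanical change: the deprecated shape-first overload Nd4j.rand(shape, distribution) becomes the distribution-first overload Nd4j.rand(distribution, shape). A standalone sketch, not part of the patch (the class name and values are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class RandMigrationSketch {
        public static void main(String[] args) {
            long[] shape = {128, 64};
            double b = 3.0 / Math.sqrt(shape[0]);
            // Deprecated form: Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
            INDArray w = Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
            System.out.println(java.util.Arrays.toString(w.shape()));
        }
    }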
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java
index c8744d69e..ca44b3a84 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java
@@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double a = 1.0 / Math.sqrt(fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java
index 1ed0e4efb..3ad193140 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java
@@ -43,7 +43,7 @@ public class VarScalingNormalUniformFanInInitScheme extends BaseWeightInitScheme
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double scalingFanIn = 3.0 / Math.sqrt(fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java
index 2405dea88..bafa2170d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java
@@ -42,7 +42,7 @@ public class VarScalingNormalUniformFanOutInitSchem
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double scalingFanOut = 3.0 / Math.sqrt(fanOut);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java
index 8f3dc2a7d..2e3d85093 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java
@@ -46,7 +46,7 @@ public class VarScalingUniformFanAvgInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java
index ddc0d7428..d1f156b01 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java
@@ -46,7 +46,7 @@ public class XavierUniformInitScheme extends BaseWeightInitScheme {
         //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
         //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
         double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-s, s), shape);
     }
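Taken together, the new *Bp ops replace the old two-step doDiff pattern (elementwise derivative op, then multiply by the incoming gradient) with a single fused call per activation. A closing sketch, not part of the patch (the class name, shapes, and values are illustrative), using the Relu6Derivative INDArray constructor added earlier in this commit:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
    import org.nd4j.linalg.factory.Nd4j;

    public class Relu6BpSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.randn(2, 5).muli(4);  // some values below 0, some above 6
            INDArray dLdOut = Nd4j.ones(2, 5);
            INDArray dLdIn = Nd4j.zerosLike(x);

            // relu6(x) = min(max(x, cutoff), 6), so dL/dIn passes dL/dOut through
            // only where cutoff < x < 6, and is 0 elsewhere (cutoff defaults to 0).
            Nd4j.exec(new Relu6Derivative(x, dLdOut, dLdIn));
            System.out.println(dLdIn);
        }
    }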