diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 92fd58d7a..918850e34 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -169,7 +169,9 @@ (53, GELU) ,\ (54, GELUDerivative), \ (55, PreciseGELU) ,\ - (56, PreciseGELUDerivative) + (56, PreciseGELUDerivative), \ + (57, Mish),\ + (58, MishDerivative) // these ops return one of FLOAT data types #define TRANSFORM_FLOAT_OPS \ diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index ab4bfca90..1cdf08130 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1718,6 +1718,32 @@ namespace simdOps { } }; + template + class Mish { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + return d1 * nd4j::math::nd4j_tanh(nd4j::math::nd4j_softplus(d1)); + } + }; + + template + class MishDerivative { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + auto ex = nd4j::math::nd4j_exp(d1); + auto e2x = ex * ex; + auto e3x = ex * ex * ex; + + return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex *(4 * d1 + 6))) / nd4j::math::nd4j_pow((2 * ex + e2x + 2), (X) 2.f); + } + }; + template class GELU { public: @@ -1954,7 +1980,7 @@ namespace simdOps { no_op_exec_special_same_cuda op_def static X op(X d1, X *params) { - return nd4j::math::softplus(d1); + return nd4j::math::nd4j_softplus(d1); } }; diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 23f6b342d..55cd4033a 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -74,6 +74,9 @@ namespace nd4j { template math_def inline Z nd4j_copysign(T val1, T val2); + template + math_def inline Z nd4j_softplus(T val); + //#ifndef __CUDACC__ template math_def inline Z nd4j_dot(X *x, Y *y, int length); @@ -159,7 +162,7 @@ namespace nd4j { math_def inline Z nd4j_sinh(T val); template - math_def inline Z softplus(T val) { + math_def inline Z nd4j_softplus(T val) { return nd4j_log((Z) 1.0f + nd4j_exp(val)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 52e725191..445be0a6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -260,40 +260,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1; -import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p; -import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan; -import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; @@ -1464,6 +1431,9 @@ public class DifferentialFunctionFactory { return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); } + public SDVariable mishDerivative(SDVariable iX) { + return new MishDerivative(sameDiff(), iX, false).outputVariable(); + } public SDVariable swish(SDVariable iX) { return new Swish(sameDiff(), iX, false).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java index 86ba0dfef..e166461e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java @@ -28,7 +28,7 @@ import org.nd4j.linalg.activations.impl.*; public enum Activation { CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6, RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH, - THRESHOLDEDRELU, GELU; + THRESHOLDEDRELU, GELU, MISH; /** * Creates an instance of the activation function @@ -77,6 +77,8 @@ public enum Activation { return new ActivationThresholdedReLU(); case GELU: return new ActivationGELU(); + case MISH: + return new ActivationMish(); default: throw new UnsupportedOperationException("Unknown or not supported activation function: " + this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java new file mode 100644 index 000000000..b789a2925 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java @@ -0,0 +1,59 @@ +/* ***************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.activations.impl; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.val; +import org.nd4j.linalg.activations.BaseActivationFunction; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish; +import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +/** + * https://arxiv.org/ftp/arxiv/papers/1908/1908.08681.pdf + */ +@EqualsAndHashCode(callSuper = false) +@Getter +public class ActivationMish extends BaseActivationFunction { + + @Override + public INDArray getActivation(INDArray in, boolean training) { + Nd4j.getExecutioner().execAndReturn(new Mish(in)); + return in; + } + + @Override + public Pair backprop(INDArray in, INDArray epsilon) { + assertShape(in, epsilon); + + val dLdZ = Nd4j.getExecutioner().exec(new MishDerivative(in, in)); + + dLdZ.muli(epsilon); + return new Pair<>(in, null); + } + + @Override + public String toString() { + return "mish"; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java new file mode 100644 index 000000000..416f74133 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.strict; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.BaseTransformStrictOp; + +import java.util.Arrays; +import java.util.List; + +/** + * Mish activation function + * + * @author raver119@gmail.com + */ +public class Mish extends BaseTransformStrictOp { + public Mish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); + } + + public Mish() { + } + + public Mish(INDArray x, INDArray z) { + super(x, z); + } + + public Mish(INDArray ndArray) { + super(ndArray); + } + + @Override + public int opNum() { + return 57; + } + + @Override + public String opName() { + return "mish"; + } + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "Mish"; + } + + + @Override + public List doDiff(List i_v) { + SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0)); + return Arrays.asList(ret); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java new file mode 100644 index 000000000..ddbbf1d1b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/MishDerivative.java @@ -0,0 +1,79 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.strict; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.BaseTransformStrictOp; + +import java.util.List; + +/** + * Mish derivative + * + * @author raver119@gmail.com + */ +public class MishDerivative extends BaseTransformStrictOp { + public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { + super(sameDiff, i_v1, i_v2); + } + + public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { + super(sameDiff, i_v1, i_v2, inPlace); + } + + public MishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); + } + + public MishDerivative() {} + + public MishDerivative(INDArray x, INDArray z) { + super(x, z); + } + + public MishDerivative(INDArray x) { + super(x); + } + + @Override + public int opNum() { + return 58; + } + + @Override + public String opName() { + return "_mishderivative"; + } + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException(); + } +}