[WIP] Mish (#55)

* Mish activation function and its derivative

Signed-off-by: raver119 <raver119@gmail.com>

* signature fix

Signed-off-by: raver119 <raver119@gmail.com>

* mish as activation for dl4j

Signed-off-by: raver119 <raver119@gmail.com>

* javadoc

Signed-off-by: raver119 <raver119@gmail.com>

* minor optimization

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-18 13:21:26 +03:00 committed by GitHub
parent c5cbdcd8f4
commit db7ca956c5
8 changed files with 256 additions and 38 deletions
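For orientation: Mish is defined as mish(x) = x * tanh(softplus(x)), with softplus(x) = ln(1 + e^x). A minimal scalar sketch in plain Java (helper names here are illustrative only, not part of this commit):

// Scalar sketch of the Mish forward pass; helper names are illustrative only.
static double softplus(double x) {
    return Math.log1p(Math.exp(x)); // ln(1 + e^x)
}

static double mish(double x) {
    return x * Math.tanh(softplus(x));
}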


@@ -169,7 +169,9 @@
 (53, GELU) ,\
 (54, GELUDerivative), \
 (55, PreciseGELU) ,\
-(56, PreciseGELUDerivative)
+(56, PreciseGELUDerivative), \
+(57, Mish),\
+(58, MishDerivative)
 
 // these ops return one of FLOAT data types
 #define TRANSFORM_FLOAT_OPS \


@@ -1718,6 +1718,32 @@ namespace simdOps {
         }
     };
 
+    template <typename X>
+    class Mish {
+    public:
+        no_op_exec_special_same
+        no_op_exec_special_same_cuda
+
+        op_def static X op(X d1, X *params) {
+            return d1 * nd4j::math::nd4j_tanh<X,X>(nd4j::math::nd4j_softplus<X,X>(d1));
+        }
+    };
+
+    template <typename X>
+    class MishDerivative {
+    public:
+        no_op_exec_special_same
+        no_op_exec_special_same_cuda
+
+        op_def static X op(X d1, X *params) {
+            auto ex = nd4j::math::nd4j_exp<X,X>(d1);
+            auto e2x = ex * ex;
+            auto e3x = e2x * ex;
+
+            return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex * (4 * d1 + 6))) / nd4j::math::nd4j_pow<X, X, X>((2 * ex + e2x + 2), (X) 2.f);
+        }
+    };
+
     template <typename X>
     class GELU {
     public:
@@ -1954,7 +1980,7 @@ namespace simdOps {
         no_op_exec_special_same_cuda
 
         op_def static X op(X d1, X *params) {
-            return nd4j::math::softplus<X, X>(d1);
+            return nd4j::math::nd4j_softplus<X, X>(d1);
         }
     };
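The closed form in MishDerivative::op above is not obvious at a glance; it follows from differentiating mish(x) = x tanh(softplus(x)) directly. A sketch of the algebra, with Δ = e^{2x} + 2e^x + 2 (the quantity squared in the denominator):

\begin{aligned}
s(x) &= \ln(1 + e^x), \qquad s'(x) = \frac{e^x}{1+e^x},\\
\operatorname{mish}'(x) &= \tanh(s(x)) + x\,\bigl(1 - \tanh^2(s(x))\bigr)\,s'(x),\\
\tanh(s(x)) &= \frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = \frac{e^{2x} + 2e^x}{\Delta},\\
\operatorname{mish}'(x) &= \frac{e^x\bigl(e^{3x} + 4e^{2x} + e^x(4x+6) + 4(x+1)\bigr)}{\Delta^2}.
\end{aligned}

The last line is exactly the numerator e^x * (4*(d1+1) + 4*e2x + e3x + ex*(4*d1+6)) and the nd4j_pow(2*ex + e2x + 2, 2) denominator used in the code.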


@@ -74,6 +74,9 @@ namespace nd4j {
         template<typename T, typename Z>
         math_def inline Z nd4j_copysign(T val1, T val2);
 
+        template <typename T, typename Z>
+        math_def inline Z nd4j_softplus(T val);
+
 //#ifndef __CUDACC__
         template<typename X, typename Y, typename Z>
         math_def inline Z nd4j_dot(X *x, Y *y, int length);
@@ -159,7 +162,7 @@ namespace nd4j {
         math_def inline Z nd4j_sinh(T val);
 
         template<typename T, typename Z>
-        math_def inline Z softplus(T val) {
+        math_def inline Z nd4j_softplus(T val) {
            return nd4j_log<T, Z>((Z) 1.0f + nd4j_exp<T, Z>(val));
         }
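One caveat: ln(1 + e^x) as written overflows exp() for large positive inputs. A numerically stable rearrangement, shown here as a plain-Java sketch (illustrative, not part of this commit), uses the identity softplus(x) = max(x, 0) + ln(1 + e^{-|x|}):

// Numerically stable softplus sketch (illustrative, not part of this commit).
static double stableSoftplus(double x) {
    // For x > 0: x + ln(1 + e^-x) == ln(e^x + 1); for x <= 0 this is plain log1p(e^x).
    return Math.max(x, 0.0) + Math.log1p(Math.exp(-Math.abs(x)));
}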


@@ -260,40 +260,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
 import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
 import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
 import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
 import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
 import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
 import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
@@ -1464,6 +1431,9 @@ public class DifferentialFunctionFactory {
         return new PowDerivative(sameDiff(), iX, false, pow).outputVariable();
     }
 
+    public SDVariable mishDerivative(SDVariable iX) {
+        return new MishDerivative(sameDiff(), iX, false).outputVariable();
+    }
 
     public SDVariable swish(SDVariable iX) {
         return new Swish(sameDiff(), iX, false).outputVariable();


@@ -28,7 +28,7 @@ import org.nd4j.linalg.activations.impl.*;
 public enum Activation {
     CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
     RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
-    THRESHOLDEDRELU, GELU;
+    THRESHOLDEDRELU, GELU, MISH;
 
     /**
      * Creates an instance of the activation function
@@ -77,6 +77,8 @@ public enum Activation {
                 return new ActivationThresholdedReLU();
             case GELU:
                 return new ActivationGELU();
+            case MISH:
+                return new ActivationMish();
             default:
                 throw new UnsupportedOperationException("Unknown or not supported activation function: " + this);
         }
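With the enum wired up, Mish can be selected like any other DL4J activation. A minimal usage sketch (the DenseLayer builder and getActivationFn accessor are assumed from the standard deeplearning4j API, not shown in this diff):

import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.activations.Activation;

public class MishLayerExample {
    public static void main(String[] args) {
        // Select Mish for a layer via the new enum constant (standard DL4J builder pattern).
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(128).nOut(64)
                .activation(Activation.MISH)
                .build();
        System.out.println(layer.getActivationFn()); // prints "mish" via ActivationMish.toString()
    }
}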


@@ -0,0 +1,59 @@
/* *****************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.activations.impl;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.val;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Mish activation function: f(x) = x * tanh(softplus(x)).
 *
 * See <a href="https://arxiv.org/ftp/arxiv/papers/1908/1908.08681.pdf">Mish: A Self Regularized
 * Non-Monotonic Neural Activation Function</a>.
 */
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationMish extends BaseActivationFunction {

    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        Nd4j.getExecutioner().execAndReturn(new Mish(in));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
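        // MishDerivative(in, in) writes mish'(in) back into `in`, so dLdZ aliases the
        // input buffer; multiplying by epsilon in-place leaves dL/dz in `in`, which is
        // what the returned Pair carries.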
        val dLdZ = Nd4j.getExecutioner().exec(new MishDerivative(in, in));
        dLdZ.muli(epsilon);
        return new Pair<>(in, null);
    }

    @Override
    public String toString() {
        return "mish";
    }
}


@@ -0,0 +1,77 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;

import java.util.Arrays;
import java.util.List;

/**
 * Mish activation function: f(x) = x * tanh(softplus(x)).
 *
 * @author raver119@gmail.com
 */
public class Mish extends BaseTransformStrictOp {
    public Mish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        super(sameDiff, i_v, inPlace);
    }

    public Mish() {
    }

    public Mish(INDArray x, INDArray z) {
        super(x, z);
    }

    public Mish(INDArray ndArray) {
        super(ndArray);
    }

    @Override
    public int opNum() {
        return 57;
    }

    @Override
    public String opName() {
        return "mish";
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        return "Mish";
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0));
        return Arrays.asList(ret);
    }
}


@@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;

import java.util.List;

/**
 * Derivative of the Mish activation function.
 *
 * @author raver119@gmail.com
 */
public class MishDerivative extends BaseTransformStrictOp {
    public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        super(sameDiff, i_v1, i_v2);
    }

    public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
        super(sameDiff, i_v1, i_v2, inPlace);
    }

    public MishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        super(sameDiff, i_v, inPlace);
    }

    public MishDerivative() {}

    public MishDerivative(INDArray x, INDArray z) {
        super(x, z);
    }

    public MishDerivative(INDArray x) {
        super(x);
    }

    @Override
    public int opNum() {
        return 58;
    }

    @Override
    public String opName() {
        return "_mishderivative";
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException();
    }
}
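As a sanity check, the analytic derivative can be compared against a central finite difference of the forward op. A minimal sketch using the constructors added in this PR (linspace/castTo/ulike and Transforms usage assumed from the standard Nd4j API):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class MishGradCheck {
    public static void main(String[] args) {
        INDArray x = Nd4j.linspace(-4, 4, 17).castTo(DataType.DOUBLE);

        // Analytic derivative from the op added in this PR.
        INDArray g = Nd4j.getExecutioner().exec(new MishDerivative(x, x.ulike()));

        // Central finite difference of the forward op: (f(x+h) - f(x-h)) / 2h.
        double h = 1e-5;
        INDArray fp = Nd4j.getExecutioner().exec(new Mish(x.add(h), x.ulike()));
        INDArray fm = Nd4j.getExecutioner().exec(new Mish(x.sub(h), x.ulike()));
        INDArray fd = fp.sub(fm).divi(2 * h);

        // Should print a very small number (central-difference error is O(h^2)).
        System.out.println(Transforms.abs(g.sub(fd)).maxNumber());
    }
}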
}