[WIP] Mish (#55)
* Mish activation function and its derivative
* signature fix
* mish as activation for dl4j
* javadoc
* minor optimization

Signed-off-by: raver119 <raver119@gmail.com>
parent
c5cbdcd8f4
commit
db7ca956c5
|
@ -169,7 +169,9 @@
|
|||
(53, GELU) ,\
|
||||
(54, GELUDerivative), \
|
||||
(55, PreciseGELU) ,\
|
||||
(56, PreciseGELUDerivative)
|
||||
(56, PreciseGELUDerivative), \
|
||||
(57, Mish),\
|
||||
(58, MishDerivative)
|
||||
|
||||
// these ops return one of FLOAT data types
|
||||
#define TRANSFORM_FLOAT_OPS \
|
||||
|
|
|
@ -1718,6 +1718,32 @@ namespace simdOps {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
class Mish {
|
||||
public:
|
||||
no_op_exec_special_same
|
||||
no_op_exec_special_same_cuda
|
||||
|
||||
op_def static X op(X d1, X *params) {
|
||||
return d1 * nd4j::math::nd4j_tanh<X,X>(nd4j::math::nd4j_softplus<X,X>(d1));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
class MishDerivative {
|
||||
public:
|
||||
no_op_exec_special_same
|
||||
no_op_exec_special_same_cuda
|
||||
|
||||
op_def static X op(X d1, X *params) {
|
||||
auto ex = nd4j::math::nd4j_exp<X,X>(d1);
|
||||
auto e2x = ex * ex;
|
||||
auto e3x = ex * ex * ex;
|
||||
|
||||
return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex *(4 * d1 + 6))) / nd4j::math::nd4j_pow<X, X, X>((2 * ex + e2x + 2), (X) 2.f);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
class GELU {
|
||||
public:
|
||||
|
@ -1954,7 +1980,7 @@ namespace simdOps {
|
|||
no_op_exec_special_same_cuda
|
||||
|
||||
op_def static X op(X d1, X *params) {
|
||||
return nd4j::math::softplus<X, X>(d1);
|
||||
return nd4j::math::nd4j_softplus<X, X>(d1);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -74,6 +74,9 @@ namespace nd4j {
|
|||
template<typename T, typename Z>
|
||||
math_def inline Z nd4j_copysign(T val1, T val2);
|
||||
|
||||
template <typename T, typename Z>
|
||||
math_def inline Z nd4j_softplus(T val);
|
||||
|
||||
//#ifndef __CUDACC__
|
||||
template<typename X, typename Y, typename Z>
|
||||
math_def inline Z nd4j_dot(X *x, Y *y, int length);
|
||||
|
@ -159,7 +162,7 @@ namespace nd4j {
|
|||
math_def inline Z nd4j_sinh(T val);
|
||||
|
||||
template<typename T, typename Z>
|
||||
math_def inline Z softplus(T val) {
|
||||
math_def inline Z nd4j_softplus(T val) {
|
||||
return nd4j_log<T, Z>((Z) 1.0f + nd4j_exp<T, Z>(val));
|
||||
}
|
||||
|
||||
|
|
|
@ -260,40 +260,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
|
||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
||||
|
@ -1464,6 +1431,9 @@ public class DifferentialFunctionFactory {
|
|||
return new PowDerivative(sameDiff(), iX, false, pow).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable mishDerivative(SDVariable iX) {
|
||||
return new MishDerivative(sameDiff(), iX, false).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable swish(SDVariable iX) {
|
||||
return new Swish(sameDiff(), iX, false).outputVariable();
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.activations.impl.*;
|
|||
public enum Activation {
|
||||
CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
|
||||
RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
|
||||
THRESHOLDEDRELU, GELU;
|
||||
THRESHOLDEDRELU, GELU, MISH;
|
||||
|
||||
/**
|
||||
* Creates an instance of the activation function
|
||||
|
@ -77,6 +77,8 @@ public enum Activation {
|
|||
return new ActivationThresholdedReLU();
|
||||
case GELU:
|
||||
return new ActivationGELU();
|
||||
case MISH:
|
||||
return new ActivationMish();
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown or not supported activation function: " + this);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/* *****************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.activations.impl;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* https://arxiv.org/ftp/arxiv/papers/1908/1908.08681.pdf
|
||||
*/
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationMish extends BaseActivationFunction {
|
||||
|
||||
@Override
|
||||
public INDArray getActivation(INDArray in, boolean training) {
|
||||
Nd4j.getExecutioner().execAndReturn(new Mish(in));
|
||||
return in;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||
assertShape(in, epsilon);
|
||||
|
||||
val dLdZ = Nd4j.getExecutioner().exec(new MishDerivative(in, in));
|
||||
|
||||
dLdZ.muli(epsilon);
|
||||
return new Pair<>(in, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "mish";
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Mish activation function
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class Mish extends BaseTransformStrictOp {
|
||||
public Mish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
|
||||
super(sameDiff, i_v, inPlace);
|
||||
}
|
||||
|
||||
public Mish() {
|
||||
}
|
||||
|
||||
public Mish(INDArray x, INDArray z) {
|
||||
super(x, z);
|
||||
}
|
||||
|
||||
public Mish(INDArray ndArray) {
|
||||
super(ndArray);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 57;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "mish";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Mish";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0));
|
||||
return Arrays.asList(ret);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Mish derivative
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class MishDerivative extends BaseTransformStrictOp {
|
||||
public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, i_v1, i_v2);
|
||||
}
|
||||
|
||||
public MishDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
|
||||
super(sameDiff, i_v1, i_v2, inPlace);
|
||||
}
|
||||
|
||||
public MishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
|
||||
super(sameDiff, i_v, inPlace);
|
||||
}
|
||||
|
||||
public MishDerivative() {}
|
||||
|
||||
public MishDerivative(INDArray x, INDArray z) {
|
||||
super(x, z);
|
||||
}
|
||||
|
||||
public MishDerivative(INDArray x) {
|
||||
super(x);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 58;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "_mishderivative";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue