parent
44a8d19ac6
commit
1f4ad08305
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -32,26 +32,6 @@ public abstract class BaseActivationFunction implements IActivation {
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParametersViewArray(INDArray viewArray, boolean initialize) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getParametersViewArray() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGradientViewArray(INDArray viewArray) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getGradientViewArray() {
|
||||
return null;
|
||||
}
|
||||
|
||||
protected void assertShape(INDArray in, INDArray epsilon){
|
||||
if(!in.equalShapes(epsilon)){
|
||||
throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -37,8 +37,8 @@ public interface IActivation extends Serializable {
|
|||
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
|
||||
* Implementations must overwrite "in", transform in place and return "in"
|
||||
* Can support separate behaviour during test
|
||||
* @param in
|
||||
* @param training
|
||||
* @param in input array.
|
||||
* @param training true when training.
|
||||
* @return transformed activation
|
||||
*/
|
||||
INDArray getActivation(INDArray in, boolean training);
|
||||
|
@ -59,12 +59,4 @@ public interface IActivation extends Serializable {
|
|||
|
||||
int numParams(int inputSize);
|
||||
|
||||
void setParametersViewArray(INDArray viewArray, boolean initialize);
|
||||
|
||||
INDArray getParametersViewArray();
|
||||
|
||||
void setGradientViewArray(INDArray viewArray);
|
||||
|
||||
INDArray getGradientViewArray();
|
||||
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -21,17 +21,15 @@ import lombok.Getter;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* f(x) = x^3
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationCube extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
*
|
||||
* alpha defaults to 1, if not specified
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationELU extends BaseActivationFunction {
|
||||
public static final double DEFAULT_ALPHA = 1.0;
|
||||
|
||||
private double alpha = DEFAULT_ALPHA;
|
||||
private double alpha;
|
||||
|
||||
public ActivationELU() {
|
||||
this(DEFAULT_ALPHA);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
*
|
||||
* @see GELU
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationGELU extends BaseActivationFunction {
|
||||
|
||||
|
@ -58,7 +58,7 @@ public class ActivationGELU extends BaseActivationFunction {
|
|||
@Override
|
||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||
assertShape(in, epsilon);
|
||||
INDArray dLdz = null;
|
||||
INDArray dLdz;
|
||||
if (precise)
|
||||
dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in));
|
||||
else
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* f(x) = min(1, max(0, 0.2*x + 0.5))
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationHardSigmoid extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,21 +18,19 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* ⎧ 1, if x > 1
|
||||
f(x) = ⎨ -1, if x < -1
|
||||
⎩ x, otherwise
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationHardTanH extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
/**
|
||||
* f(x) = x
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationIdentity extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,26 +18,24 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* Leaky RELU
|
||||
* f(x) = max(0, x) + alpha * min(0, x)
|
||||
* alpha defaults to 0.01
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationLReLU extends BaseActivationFunction {
|
||||
public static final double DEFAULT_ALPHA = 0.01;
|
||||
|
||||
private double alpha = DEFAULT_ALPHA;
|
||||
private double alpha;
|
||||
|
||||
public ActivationLReLU() {
|
||||
this(DEFAULT_ALPHA);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -33,17 +33,13 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
*
|
||||
* @author Max Pumperla
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationPReLU extends BaseActivationFunction {
|
||||
|
||||
private INDArray alpha;
|
||||
private long[] sharedAxes = null;
|
||||
|
||||
public ActivationPReLU(INDArray alpha) {
|
||||
this.alpha = alpha;
|
||||
}
|
||||
|
||||
public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
|
||||
this.alpha = alpha;
|
||||
this.sharedAxes = sharedAxes;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -37,7 +37,7 @@ import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|||
* <a href="http://arxiv.org/abs/1505.00853">
|
||||
* Empirical Evaluation of Rectified Activations in Convolutional Network</a>
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@JsonIgnoreProperties({"alpha"})
|
||||
@Getter
|
||||
public class ActivationRReLU extends BaseActivationFunction {
|
||||
|
@ -62,7 +62,7 @@ public class ActivationRReLU extends BaseActivationFunction {
|
|||
@Override
|
||||
public INDArray getActivation(INDArray in, boolean training) {
|
||||
if (training) {
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
try(MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
|
||||
}
|
||||
INDArray inTimesAlpha = in.mul(alpha);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* Rational tanh approximation
|
||||
|
@ -37,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
*
|
||||
* Underlying implementation is in native code
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationRationalTanh extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
/**
|
||||
* f(x) = max(0, x)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationReLU extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -20,9 +20,7 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.Getter;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
@ -30,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
/**
|
||||
* f(x) = min(max(input, cutoff), 6)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationReLU6 extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* Rectified tanh
|
||||
|
@ -34,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
*
|
||||
* Underlying implementation is in native code
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationRectifiedTanh extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* https://arxiv.org/pdf/1706.02515.pdf
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSELU extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,18 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* f(x) = 1 / (1 + exp(-x))
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSigmoid extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* f(x) = log(1+e^x)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSoftPlus extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* f_i(x) = x_i / (1+|x_i|)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSoftSign extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -20,8 +20,6 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.Getter;
|
||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -31,13 +29,13 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
* f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
|
||||
* where shift = max_i(x_i)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSoftmax extends BaseActivationFunction {
|
||||
|
||||
@Override
|
||||
public INDArray getActivation(INDArray in, boolean training) {
|
||||
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in));
|
||||
Nd4j.getExecutioner().execAndReturn(new SoftMax(in, in));
|
||||
return in;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
/**
|
||||
* f(x) = x * sigmoid(x)
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationSwish extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
/**
|
||||
* f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationTanH extends BaseActivationFunction {
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
*
|
||||
* @author Max Pumperla
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Getter
|
||||
public class ActivationThresholdedReLU extends BaseActivationFunction {
|
||||
|
||||
public static final double DEFAULT_THETA = 1.0;
|
||||
private double theta = DEFAULT_THETA;
|
||||
private double theta;
|
||||
|
||||
public ActivationThresholdedReLU() {
|
||||
this(DEFAULT_THETA);
|
||||
|
|
Loading…
Reference in New Issue