refactoring activations. (#8261)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
branch master
Robert Altena 2019-10-03 19:35:27 +09:00, committed by Alex Black
parent 44a8d19ac6
commit 1f4ad08305
24 changed files with 71 additions and 126 deletions
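One change repeats across nearly every file in this commit: Lombok's bare @EqualsAndHashCode is replaced with @EqualsAndHashCode(callSuper = false). The flag states explicitly that the generated equals()/hashCode() use only the subclass's own fields, and it silences Lombok's warning about annotating a class that extends something other than Object. A minimal illustration with hypothetical classes, not code from this commit:

import lombok.EqualsAndHashCode;

class Base { }

// equals()/hashCode() are generated from Derived's own fields only; state in
// Base (if any) is deliberately excluded. A bare @EqualsAndHashCode here would
// trigger Lombok's "without a call to superclass" warning.
@EqualsAndHashCode(callSuper = false)
class Derived extends Base {
    private double alpha;
}

This is safe for the activation classes because BaseActivationFunction carries no fields of its own.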

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -32,26 +32,6 @@ public abstract class BaseActivationFunction implements IActivation {
return 0;
}
@Override
public void setParametersViewArray(INDArray viewArray, boolean initialize) {
//No op
}
@Override
public INDArray getParametersViewArray() {
return null;
}
@Override
public void setGradientViewArray(INDArray viewArray) {
//No op
}
@Override
public INDArray getGradientViewArray() {
return null;
}
protected void assertShape(INDArray in, INDArray epsilon){
if(!in.equalShapes(epsilon)){
throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -37,8 +37,8 @@ public interface IActivation extends Serializable {
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
* Implementations must overwrite "in", transform in place and return "in"
* Can support separate behaviour during test
* @param in
* @param training
* @param in input array.
* @param training true when training.
* @return transformed activation
*/
INDArray getActivation(INDArray in, boolean training);
@@ -59,12 +59,4 @@ public interface IActivation extends Serializable {
int numParams(int inputSize);
void setParametersViewArray(INDArray viewArray, boolean initialize);
INDArray getParametersViewArray();
void setGradientViewArray(INDArray viewArray);
INDArray getGradientViewArray();
}
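With the parameter- and gradient-view methods removed from the interface, and their no-op bodies removed from BaseActivationFunction above, a stateless activation only has to implement the forward pass and backprop. A minimal sketch against the slimmed-down contract, using a hypothetical ActivationSquare that follows the in-place convention documented in getActivation's Javadoc:

import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

public class ActivationSquare extends BaseActivationFunction {
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        return in.muli(in); // f(x) = x^2: transform "in" in place and return it
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        // dL/dz = epsilon * f'(z) = epsilon * 2z, also computed in place
        return new Pair<>(in.muli(2).muli(epsilon), null);
    }
}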

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -21,17 +21,15 @@ import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = x^3
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationCube extends BaseActivationFunction {
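The import swap above, CubeDerivative out and CubeBp plus DynamicCustomOp in, is the substance of the refactoring: the derivative is now computed by a dedicated backprop op instead of a legacy derivative transform chained with a multiply. A sketch of what the refactored method plausibly looks like; the CubeBp argument order (input, incoming gradient, output) is an assumption, not something the diff confirms:

@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    assertShape(in, epsilon);
    // one fused native op computes dL/dz = epsilon * 3z^2
    Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));
    return new Pair<>(in, null);
}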

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
*
* alpha defaults to 1, if not specified
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationELU extends BaseActivationFunction {
public static final double DEFAULT_ALPHA = 1.0;
private double alpha = DEFAULT_ALPHA;
private double alpha;
public ActivationELU() {
this(DEFAULT_ALPHA);
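Dropping the = DEFAULT_ALPHA field initializer changes nothing observable: the no-arg constructor already delegates, so alpha was being assigned twice. The same cleanup recurs in ActivationLReLU and ActivationThresholdedReLU below. A sketch of the resulting pair, with the body of the parameterized constructor assumed:

public ActivationELU() {
    this(DEFAULT_ALPHA); // alpha defaults to 1.0
}

public ActivationELU(double alpha) {
    this.alpha = alpha;  // now the single place where alpha is set
}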

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -32,7 +32,7 @@ import org.nd4j.linalg.primitives.Pair;
*
* @see GELU
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationGELU extends BaseActivationFunction {
@@ -58,7 +58,7 @@ public class ActivationGELU extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = null;
INDArray dLdz;
if (precise)
dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in));
else

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = min(1, max(0, 0.2*x + 0.5))
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationHardSigmoid extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,21 +18,19 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
*          1, if x > 1
* f(x) =  -1, if x < -1
*          x, otherwise
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationHardTanH extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
/**
* f(x) = x
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationIdentity extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,26 +18,24 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* Leaky RELU
* f(x) = max(0, x) + alpha * min(0, x)
* alpha defaults to 0.01
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationLReLU extends BaseActivationFunction {
public static final double DEFAULT_ALPHA = 0.01;
private double alpha = DEFAULT_ALPHA;
private double alpha;
public ActivationLReLU() {
this(DEFAULT_ALPHA);

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -33,17 +33,13 @@ import org.nd4j.linalg.primitives.Pair;
*
* @author Max Pumperla
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationPReLU extends BaseActivationFunction {
private INDArray alpha;
private long[] sharedAxes = null;
public ActivationPReLU(INDArray alpha) {
this.alpha = alpha;
}
public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
this.alpha = alpha;
this.sharedAxes = sharedAxes;

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -37,7 +37,7 @@ import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
* <a href="http://arxiv.org/abs/1505.00853">
* Empirical Evaluation of Rectified Activations in Convolutional Network</a>
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"alpha"})
@Getter
public class ActivationRReLU extends BaseActivationFunction {
@@ -62,7 +62,7 @@ public class ActivationRReLU extends BaseActivationFunction {
@Override
public INDArray getActivation(INDArray in, boolean training) {
if (training) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
try(MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
}
INDArray inTimesAlpha = in.mul(alpha);
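Renaming the unused try-with-resources variable from ws to ignored signals that the handle exists only for its side effect: while the scope is open, allocations land outside any active workspace, so the alpha field stays valid after the block closes. A standalone sketch of the idiom with hypothetical bounds and shape:

INDArray alpha;
try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
    // created off-workspace, so it may safely outlive the block, e.g. in a field
    alpha = Nd4j.rand(0.0, 1.0, Nd4j.getRandom(), new long[]{2, 3});
}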

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* Rational tanh approximation
@@ -37,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
*
* Underlying implementation is in native code
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationRationalTanh extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
/**
* f(x) = max(0, x)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationReLU extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -20,9 +20,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@@ -30,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = min(max(input, cutoff), 6)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationReLU6 extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* Rectified tanh
@@ -34,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
*
* Underlying implementation is in native code
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationRectifiedTanh extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* https://arxiv.org/pdf/1706.02515.pdf
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSELU extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,18 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = 1 / (1 + exp(-x))
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSigmoid extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = log(1+e^x)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSoftPlus extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/**
* f_i(x) = x_i / (1+|x_i|)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSoftSign extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -20,8 +20,6 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.factory.Nd4j;
@@ -31,13 +29,13 @@ import org.nd4j.linalg.primitives.Pair;
* f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
* where shift = max_i(x_i)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSoftmax extends BaseActivationFunction {
@Override
public INDArray getActivation(INDArray in, boolean training) {
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in));
Nd4j.getExecutioner().execAndReturn(new SoftMax(in, in));
return in;
}
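Since SoftMax is imported from transforms.custom, it already is a CustomOp and the cast was redundant. A quick usage sketch of the in-place call on a hypothetical input:

INDArray z = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}});
// overwrite z with its row-wise softmax; each row then sums to 1.0
Nd4j.getExecutioner().execAndReturn(new SoftMax(z, z));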

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -28,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
/**
* f(x) = x * sigmoid(x)
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationSwish extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
/**
* f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationTanH extends BaseActivationFunction {

View File

@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
*
* @author Max Pumperla
*/
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
@Getter
public class ActivationThresholdedReLU extends BaseActivationFunction {
public static final double DEFAULT_THETA = 1.0;
private double theta = DEFAULT_THETA;
private double theta;
public ActivationThresholdedReLU() {
this(DEFAULT_THETA);