refactoring activations. (#8261)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-10-03 19:35:27 +09:00 committed by Alex Black
parent 44a8d19ac6
commit 1f4ad08305
24 changed files with 71 additions and 126 deletions

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -32,26 +32,6 @@ public abstract class BaseActivationFunction implements IActivation {
return 0; return 0;
} }
@Override
public void setParametersViewArray(INDArray viewArray, boolean initialize) {
//No op
}
@Override
public INDArray getParametersViewArray() {
return null;
}
@Override
public void setGradientViewArray(INDArray viewArray) {
//No op
}
@Override
public INDArray getGradientViewArray() {
return null;
}
protected void assertShape(INDArray in, INDArray epsilon){ protected void assertShape(INDArray in, INDArray epsilon){
if(!in.equalShapes(epsilon)){ if(!in.equalShapes(epsilon)){
throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape()) throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -37,8 +37,8 @@ public interface IActivation extends Serializable {
* Carry out activation function on the input array (usually known as 'preOut' or 'z') * Carry out activation function on the input array (usually known as 'preOut' or 'z')
* Implementations must overwrite "in", transform in place and return "in" * Implementations must overwrite "in", transform in place and return "in"
* Can support separate behaviour during test * Can support separate behaviour during test
* @param in * @param in input array.
* @param training * @param training true when training.
* @return transformed activation * @return transformed activation
*/ */
INDArray getActivation(INDArray in, boolean training); INDArray getActivation(INDArray in, boolean training);
@ -59,12 +59,4 @@ public interface IActivation extends Serializable {
int numParams(int inputSize); int numParams(int inputSize);
void setParametersViewArray(INDArray viewArray, boolean initialize);
INDArray getParametersViewArray();
void setGradientViewArray(INDArray viewArray);
INDArray getGradientViewArray();
} }

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -21,17 +21,15 @@ import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = x^3 * f(x) = x^3
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationCube extends BaseActivationFunction { public class ActivationCube extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
* *
* alpha defaults to 1, if not specified * alpha defaults to 1, if not specified
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationELU extends BaseActivationFunction { public class ActivationELU extends BaseActivationFunction {
public static final double DEFAULT_ALPHA = 1.0; public static final double DEFAULT_ALPHA = 1.0;
private double alpha = DEFAULT_ALPHA; private double alpha;
public ActivationELU() { public ActivationELU() {
this(DEFAULT_ALPHA); this(DEFAULT_ALPHA);

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -32,7 +32,7 @@ import org.nd4j.linalg.primitives.Pair;
* *
* @see GELU * @see GELU
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationGELU extends BaseActivationFunction { public class ActivationGELU extends BaseActivationFunction {
@ -58,7 +58,7 @@ public class ActivationGELU extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = null; INDArray dLdz;
if (precise) if (precise)
dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in)); dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in));
else else

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid; import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = min(1, max(0, 0.2*x + 0.5)) * f(x) = min(1, max(0, 0.2*x + 0.5))
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationHardSigmoid extends BaseActivationFunction { public class ActivationHardSigmoid extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,21 +18,19 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh; import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* 1, if x > 1 * 1, if x > 1
f(x) = -1, if x < -1 f(x) = -1, if x < -1
x, otherwise x, otherwise
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationHardTanH extends BaseActivationFunction { public class ActivationHardTanH extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
* f(x) = x * f(x) = x
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationIdentity extends BaseActivationFunction { public class ActivationIdentity extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,26 +18,24 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* Leaky RELU * Leaky RELU
* f(x) = max(0, x) + alpha * min(0, x) * f(x) = max(0, x) + alpha * min(0, x)
* alpha defaults to 0.01 * alpha defaults to 0.01
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationLReLU extends BaseActivationFunction { public class ActivationLReLU extends BaseActivationFunction {
public static final double DEFAULT_ALPHA = 0.01; public static final double DEFAULT_ALPHA = 0.01;
private double alpha = DEFAULT_ALPHA; private double alpha;
public ActivationLReLU() { public ActivationLReLU() {
this(DEFAULT_ALPHA); this(DEFAULT_ALPHA);

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -33,17 +33,13 @@ import org.nd4j.linalg.primitives.Pair;
* *
* @author Max Pumperla * @author Max Pumperla
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationPReLU extends BaseActivationFunction { public class ActivationPReLU extends BaseActivationFunction {
private INDArray alpha; private INDArray alpha;
private long[] sharedAxes = null; private long[] sharedAxes = null;
public ActivationPReLU(INDArray alpha) {
this.alpha = alpha;
}
public ActivationPReLU(INDArray alpha, long[] sharedAxes) { public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
this.alpha = alpha; this.alpha = alpha;
this.sharedAxes = sharedAxes; this.sharedAxes = sharedAxes;

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -37,7 +37,7 @@ import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
* <a href="http://arxiv.org/abs/1505.00853"> * <a href="http://arxiv.org/abs/1505.00853">
* Empirical Evaluation of Rectified Activations in Convolutional Network</a> * Empirical Evaluation of Rectified Activations in Convolutional Network</a>
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"alpha"}) @JsonIgnoreProperties({"alpha"})
@Getter @Getter
public class ActivationRReLU extends BaseActivationFunction { public class ActivationRReLU extends BaseActivationFunction {
@ -62,7 +62,7 @@ public class ActivationRReLU extends BaseActivationFunction {
@Override @Override
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
if (training) { if (training) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape()); this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
} }
INDArray inTimesAlpha = in.mul(alpha); INDArray inTimesAlpha = in.mul(alpha);

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh; import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* Rational tanh approximation * Rational tanh approximation
@ -37,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
* *
* Underlying implementation is in native code * Underlying implementation is in native code
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationRationalTanh extends BaseActivationFunction { public class ActivationRationalTanh extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
/** /**
* f(x) = max(0, x) * f(x) = max(0, x)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationReLU extends BaseActivationFunction { public class ActivationReLU extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -20,9 +20,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6; import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -30,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = min(max(input, cutoff), 6) * f(x) = min(max(input, cutoff), 6)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationReLU6 extends BaseActivationFunction { public class ActivationReLU6 extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh; import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* Rectified tanh * Rectified tanh
@ -34,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
* *
* Underlying implementation is in native code * Underlying implementation is in native code
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationRectifiedTanh extends BaseActivationFunction { public class ActivationRectifiedTanh extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* https://arxiv.org/pdf/1706.02515.pdf * https://arxiv.org/pdf/1706.02515.pdf
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSELU extends BaseActivationFunction { public class ActivationSELU extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,18 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = 1 / (1 + exp(-x)) * f(x) = 1 / (1 + exp(-x))
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSigmoid extends BaseActivationFunction { public class ActivationSigmoid extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus; import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = log(1+e^x) * f(x) = log(1+e^x)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSoftPlus extends BaseActivationFunction { public class ActivationSoftPlus extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign; import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
/** /**
* f_i(x) = x_i / (1+|x_i|) * f_i(x) = x_i / (1+|x_i|)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSoftSign extends BaseActivationFunction { public class ActivationSoftSign extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -20,8 +20,6 @@ import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -31,13 +29,13 @@ import org.nd4j.linalg.primitives.Pair;
* f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift) * f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
* where shift = max_i(x_i) * where shift = max_i(x_i)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSoftmax extends BaseActivationFunction { public class ActivationSoftmax extends BaseActivationFunction {
@Override @Override
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in)); Nd4j.getExecutioner().execAndReturn(new SoftMax(in, in));
return in; return in;
} }

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -28,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
/** /**
* f(x) = x * sigmoid(x) * f(x) = x * sigmoid(x)
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationSwish extends BaseActivationFunction { public class ActivationSwish extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
/** /**
* f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) * f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationTanH extends BaseActivationFunction { public class ActivationTanH extends BaseActivationFunction {

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
* *
* @author Max Pumperla * @author Max Pumperla
*/ */
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
@Getter @Getter
public class ActivationThresholdedReLU extends BaseActivationFunction { public class ActivationThresholdedReLU extends BaseActivationFunction {
public static final double DEFAULT_THETA = 1.0; public static final double DEFAULT_THETA = 1.0;
private double theta = DEFAULT_THETA; private double theta;
public ActivationThresholdedReLU() { public ActivationThresholdedReLU() {
this(DEFAULT_THETA); this(DEFAULT_THETA);