parent
44a8d19ac6
commit
1f4ad08305
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -32,26 +32,6 @@ public abstract class BaseActivationFunction implements IActivation {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setParametersViewArray(INDArray viewArray, boolean initialize) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getParametersViewArray() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setGradientViewArray(INDArray viewArray) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getGradientViewArray() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void assertShape(INDArray in, INDArray epsilon){
|
protected void assertShape(INDArray in, INDArray epsilon){
|
||||||
if(!in.equalShapes(epsilon)){
|
if(!in.equalShapes(epsilon)){
|
||||||
throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())
|
throw new IllegalStateException("Shapes must be equal during backprop: in.shape{} = " + Arrays.toString(in.shape())
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -37,8 +37,8 @@ public interface IActivation extends Serializable {
|
||||||
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
|
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
|
||||||
* Implementations must overwrite "in", transform in place and return "in"
|
* Implementations must overwrite "in", transform in place and return "in"
|
||||||
* Can support separate behaviour during test
|
* Can support separate behaviour during test
|
||||||
* @param in
|
* @param in input array.
|
||||||
* @param training
|
* @param training true when training.
|
||||||
* @return transformed activation
|
* @return transformed activation
|
||||||
*/
|
*/
|
||||||
INDArray getActivation(INDArray in, boolean training);
|
INDArray getActivation(INDArray in, boolean training);
|
||||||
|
@ -59,12 +59,4 @@ public interface IActivation extends Serializable {
|
||||||
|
|
||||||
int numParams(int inputSize);
|
int numParams(int inputSize);
|
||||||
|
|
||||||
void setParametersViewArray(INDArray viewArray, boolean initialize);
|
|
||||||
|
|
||||||
INDArray getParametersViewArray();
|
|
||||||
|
|
||||||
void setGradientViewArray(INDArray viewArray);
|
|
||||||
|
|
||||||
INDArray getGradientViewArray();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -21,17 +21,15 @@ import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
|
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f(x) = x^3
|
* f(x) = x^3
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationCube extends BaseActivationFunction {
|
public class ActivationCube extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
*
|
*
|
||||||
* alpha defaults to 1, if not specified
|
* alpha defaults to 1, if not specified
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationELU extends BaseActivationFunction {
|
public class ActivationELU extends BaseActivationFunction {
|
||||||
public static final double DEFAULT_ALPHA = 1.0;
|
public static final double DEFAULT_ALPHA = 1.0;
|
||||||
|
|
||||||
private double alpha = DEFAULT_ALPHA;
|
private double alpha;
|
||||||
|
|
||||||
public ActivationELU() {
|
public ActivationELU() {
|
||||||
this(DEFAULT_ALPHA);
|
this(DEFAULT_ALPHA);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
*
|
*
|
||||||
* @see GELU
|
* @see GELU
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationGELU extends BaseActivationFunction {
|
public class ActivationGELU extends BaseActivationFunction {
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ public class ActivationGELU extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = null;
|
INDArray dLdz;
|
||||||
if (precise)
|
if (precise)
|
||||||
dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in));
|
dLdz = Nd4j.getExecutioner().exec(new PreciseGELUDerivative(in, in));
|
||||||
else
|
else
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f(x) = min(1, max(0, 0.2*x + 0.5))
|
* f(x) = min(1, max(0, 0.2*x + 0.5))
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationHardSigmoid extends BaseActivationFunction {
|
public class ActivationHardSigmoid extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,21 +18,19 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ⎧ 1, if x > 1
|
* ⎧ 1, if x > 1
|
||||||
f(x) = ⎨ -1, if x < -1
|
f(x) = ⎨ -1, if x < -1
|
||||||
⎩ x, otherwise
|
⎩ x, otherwise
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationHardTanH extends BaseActivationFunction {
|
public class ActivationHardTanH extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
/**
|
/**
|
||||||
* f(x) = x
|
* f(x) = x
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationIdentity extends BaseActivationFunction {
|
public class ActivationIdentity extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,26 +18,24 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Leaky RELU
|
* Leaky RELU
|
||||||
* f(x) = max(0, x) + alpha * min(0, x)
|
* f(x) = max(0, x) + alpha * min(0, x)
|
||||||
* alpha defaults to 0.01
|
* alpha defaults to 0.01
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationLReLU extends BaseActivationFunction {
|
public class ActivationLReLU extends BaseActivationFunction {
|
||||||
public static final double DEFAULT_ALPHA = 0.01;
|
public static final double DEFAULT_ALPHA = 0.01;
|
||||||
|
|
||||||
private double alpha = DEFAULT_ALPHA;
|
private double alpha;
|
||||||
|
|
||||||
public ActivationLReLU() {
|
public ActivationLReLU() {
|
||||||
this(DEFAULT_ALPHA);
|
this(DEFAULT_ALPHA);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -33,17 +33,13 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
*
|
*
|
||||||
* @author Max Pumperla
|
* @author Max Pumperla
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationPReLU extends BaseActivationFunction {
|
public class ActivationPReLU extends BaseActivationFunction {
|
||||||
|
|
||||||
private INDArray alpha;
|
private INDArray alpha;
|
||||||
private long[] sharedAxes = null;
|
private long[] sharedAxes = null;
|
||||||
|
|
||||||
public ActivationPReLU(INDArray alpha) {
|
|
||||||
this.alpha = alpha;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
|
public ActivationPReLU(INDArray alpha, long[] sharedAxes) {
|
||||||
this.alpha = alpha;
|
this.alpha = alpha;
|
||||||
this.sharedAxes = sharedAxes;
|
this.sharedAxes = sharedAxes;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -37,7 +37,7 @@ import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||||
* <a href="http://arxiv.org/abs/1505.00853">
|
* <a href="http://arxiv.org/abs/1505.00853">
|
||||||
* Empirical Evaluation of Rectified Activations in Convolutional Network</a>
|
* Empirical Evaluation of Rectified Activations in Convolutional Network</a>
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@JsonIgnoreProperties({"alpha"})
|
@JsonIgnoreProperties({"alpha"})
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationRReLU extends BaseActivationFunction {
|
public class ActivationRReLU extends BaseActivationFunction {
|
||||||
|
@ -62,7 +62,7 @@ public class ActivationRReLU extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
if (training) {
|
if (training) {
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try(MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
|
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
|
||||||
}
|
}
|
||||||
INDArray inTimesAlpha = in.mul(alpha);
|
INDArray inTimesAlpha = in.mul(alpha);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Rational tanh approximation
|
* Rational tanh approximation
|
||||||
|
@ -37,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*
|
*
|
||||||
* Underlying implementation is in native code
|
* Underlying implementation is in native code
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationRationalTanh extends BaseActivationFunction {
|
public class ActivationRationalTanh extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
/**
|
/**
|
||||||
* f(x) = max(0, x)
|
* f(x) = max(0, x)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationReLU extends BaseActivationFunction {
|
public class ActivationReLU extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -20,9 +20,7 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
|
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
@ -30,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
/**
|
/**
|
||||||
* f(x) = min(max(input, cutoff), 6)
|
* f(x) = min(max(input, cutoff), 6)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationReLU6 extends BaseActivationFunction {
|
public class ActivationReLU6 extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Rectified tanh
|
* Rectified tanh
|
||||||
|
@ -34,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*
|
*
|
||||||
* Underlying implementation is in native code
|
* Underlying implementation is in native code
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationRectifiedTanh extends BaseActivationFunction {
|
public class ActivationRectifiedTanh extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* https://arxiv.org/pdf/1706.02515.pdf
|
* https://arxiv.org/pdf/1706.02515.pdf
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSELU extends BaseActivationFunction {
|
public class ActivationSELU extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,18 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f(x) = 1 / (1 + exp(-x))
|
* f(x) = 1 / (1 + exp(-x))
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSigmoid extends BaseActivationFunction {
|
public class ActivationSigmoid extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f(x) = log(1+e^x)
|
* f(x) = log(1+e^x)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSoftPlus extends BaseActivationFunction {
|
public class ActivationSoftPlus extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -18,19 +18,17 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* f_i(x) = x_i / (1+|x_i|)
|
* f_i(x) = x_i / (1+|x_i|)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSoftSign extends BaseActivationFunction {
|
public class ActivationSoftSign extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -20,8 +20,6 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -31,13 +29,13 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
* f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
|
* f_i(x) = exp(x_i - shift) / sum_j exp(x_j - shift)
|
||||||
* where shift = max_i(x_i)
|
* where shift = max_i(x_i)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSoftmax extends BaseActivationFunction {
|
public class ActivationSoftmax extends BaseActivationFunction {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in));
|
Nd4j.getExecutioner().execAndReturn(new SoftMax(in, in));
|
||||||
return in;
|
return in;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
/**
|
/**
|
||||||
* f(x) = x * sigmoid(x)
|
* f(x) = x * sigmoid(x)
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationSwish extends BaseActivationFunction {
|
public class ActivationSwish extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
/**
|
/**
|
||||||
* f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
* f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationTanH extends BaseActivationFunction {
|
public class ActivationTanH extends BaseActivationFunction {
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -31,12 +31,12 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
*
|
*
|
||||||
* @author Max Pumperla
|
* @author Max Pumperla
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Getter
|
@Getter
|
||||||
public class ActivationThresholdedReLU extends BaseActivationFunction {
|
public class ActivationThresholdedReLU extends BaseActivationFunction {
|
||||||
|
|
||||||
public static final double DEFAULT_THETA = 1.0;
|
public static final double DEFAULT_THETA = 1.0;
|
||||||
private double theta = DEFAULT_THETA;
|
private double theta;
|
||||||
|
|
||||||
public ActivationThresholdedReLU() {
|
public ActivationThresholdedReLU() {
|
||||||
this(DEFAULT_THETA);
|
this(DEFAULT_THETA);
|
||||||
|
|
Loading…
Reference in New Issue