New Nd4j backprop ops for activations (#211)
* new (for java at least) backprop ops
* update activation functions
* add differential functions for SameDiff
* deprecate old ops
* update correct old ops
* update ops backprop to use new ops
* misc updates for deprecated functions (mostly Nd4j.rand w/ vararg shape)
* remove old imports

Signed-off-by: Ryan Nett <rnett@skymind.io>
parent 6d04d30c94
commit b3a134b608
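In short, each activation's backprop switches from running a *Derivative transform op (dOut/dIn) and multiplying by epsilon to running a single fused *_bp custom op that takes the input and dL/dOut and writes dL/dIn directly. A minimal sketch of the new pattern, using the CubeBp op added in this commit (shapes and values here are hypothetical, not taken from the diff):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.factory.Nd4j;

// Old pattern: exec(new CubeDerivative(in)) followed by muli(epsilon).
// New pattern: one cube_bp call computes dL/dIn from (in, dL/dOut).
INDArray in = Nd4j.rand(2, 3);       // pre-activation input
INDArray eps = Nd4j.rand(2, 3);      // dL/dOut from the layer above
INDArray dLdIn = in.ulike();         // result buffer
Nd4j.getExecutioner().execAndReturn(new CubeBp(in, eps, dLdIn));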
@@ -204,20 +204,30 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
@@ -1126,10 +1136,26 @@ public class DifferentialFunctionFactory {
return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
}

public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) {
return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable();
}

public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) {
return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* Use {@link #tanhRationalBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRationalDerivative(SDVariable in) {
return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
}

/**
* Use {@link #tanhRectifiedBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRectifiedDerivative(SDVariable in) {
return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
}
@@ -1280,6 +1306,14 @@ public class DifferentialFunctionFactory {
return new Cube(sameDiff(), iX, false).outputVariable();
}

public SDVariable cubeBp(SDVariable in, SDVariable epsilon) {
return new CubeBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* @deprecated Use {@link #cubeBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable cubeDerivative(SDVariable iX) {
return new CubeDerivative(sameDiff(), iX, false).outputVariable();
}
@@ -1329,6 +1363,14 @@ public class DifferentialFunctionFactory {
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
}

public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable();
}

public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}

public SDVariable relu6(SDVariable iX, double cutoff) {
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
}
@@ -1350,6 +1392,14 @@ public class DifferentialFunctionFactory {
return new HardTanh(sameDiff(), iX, false).outputVariable();
}

public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) {
return new HardTanhBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable hardTanhDerivative(SDVariable iX) {
return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
}
@@ -1358,6 +1408,9 @@ public class DifferentialFunctionFactory {
return new HardSigmoid(sameDiff(), in, false).outputVariable();
}

public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){
return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable();
}

public SDVariable sigmoid(SDVariable iX) {
return new Sigmoid(sameDiff(), iX, false).outputVariable();
@@ -1486,10 +1539,16 @@ public class DifferentialFunctionFactory {
}

public SDVariable softsignBp(SDVariable in, SDVariable epsilon) {
return new SoftSignBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* @deprecated Use {@link #softsignBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable softsignDerivative(SDVariable iX) {
return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
}
@@ -1504,10 +1563,16 @@ public class DifferentialFunctionFactory {
}

public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
return new EluBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable eluDerivative(SDVariable iX) {
return new ELUDerivative(sameDiff(), iX, false).outputVariable();
}
@@ -1516,6 +1581,14 @@ public class DifferentialFunctionFactory {
}

public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) {
return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}

/**
* @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)}
*/
@Deprecated
public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
}
@@ -1832,7 +1905,15 @@ public class DifferentialFunctionFactory {
return new SELU(sameDiff(), arg, false).outputVariable();
}

public SDVariable seluBp(SDVariable in, SDVariable epsilon) {
validateDifferentialFunctionsameDiff(in);
return new SeluBp(sameDiff(), in, epsilon).outputVariable();
}

/**
* @deprecated Use {@link #seluBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable seluDerivative(SDVariable arg) {
validateDifferentialFunctionsameDiff(arg);
return new SELUDerivative(sameDiff(), arg, false).outputVariable();
@@ -901,6 +901,17 @@ public class OpValidation {
TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
PowDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,

BiasAddGrad.class,
ConcatBp.class,
@@ -229,6 +229,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
@@ -433,6 +434,17 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,
@@ -21,6 +21,8 @@ import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j;
@@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -85,9 +86,8 @@ public class ActivationELU extends BaseActivationFunction {
}

else {
INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
return new Pair<>(in, null);
}
}
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha));

return new Pair<>(in, null);
}

@Override
@@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) {
if (training) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
}
INDArray inTimesAlpha = in.mul(alpha);
BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -45,9 +47,10 @@ public class ActivationRectifiedTanh extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -21,7 +21,9 @@ import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
INDArray x = out.mul(epsilon).sum(1);
INDArray dLdz = out.mul(epsilon.subColumnVector(x));
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1));

return new Pair<>(in, null);
}

@Override
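For reference, the replaced lines compute the softmax gradient explicitly, and the new softmax_bp op (invoked here with dimension -1) computes the same quantity natively. The identity being implemented, with y = softmax(z) taken along the feature dimension and eps = dL/dy, is (my restatement, not part of the diff):

dL/dz_i = y_i * (eps_i - sum_j eps_j * y_j)

which is exactly what the old code evaluates via out.mul(epsilon.subColumnVector(out.mul(epsilon).sum(1))).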
@@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;

/**
@@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);

Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));

return new Pair<>(in, null);
}

@Override
@@ -16,17 +16,13 @@
package org.nd4j.linalg.api.ops.impl.scalar;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha));
}

@Override
@@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp {
return "Relu";
}

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
if(scalarValue.getDouble(0) == 0.0){
return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
} else {
SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
SDVariable ret = step.mul(i_v.get(0));
return Collections.singletonList(ret);
}
return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0)));
}
}
@@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.shade.guava.base.Preconditions;

import java.util.Collections;
import java.util.List;
@@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.custom;

import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;

/**
* Threshold ReLU op. The general case of {@link RectifiedLinear}.
*/
public class ThresholdRelu extends DynamicCustomOp {

@Getter
private double cutoff = 0.0;

public ThresholdRelu(){ }

public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){
super(sd, new SDVariable[]{input}, inPlace);
this.cutoff = cutoff;
addTArgument(cutoff);
}

public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){
super(sd, new SDVariable[]{input});
this.cutoff = cutoff;
addTArgument(cutoff);
}

public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){
super(new INDArray[]{input}, wrapOrNull(output));
this.cutoff = cutoff;
addTArgument(cutoff);
}

@Override
public String opName(){
return "thresholdedrelu";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff));
}
}
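As a quick sanity-check sketch (not part of the diff), the forward op above and its backprop counterpart can be run directly on INDArrays; this assumes ThresholdReluBp exposes an (input, gradient, output, cutoff) constructor analogous to LeakyReLUBp, which is not shown in this excerpt:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.factory.Nd4j;

INDArray in = Nd4j.rand(2, 3).subi(0.5);                              // values above and below the cutoff
INDArray out = in.ulike();
Nd4j.getExecutioner().exec(new ThresholdRelu(in, out, 0.1));          // forward: x if x > cutoff, else 0
INDArray eps = Nd4j.ones(2, 3);                                       // pretend dL/dOut
INDArray dLdIn = in.ulike();
Nd4j.getExecutioner().exec(new ThresholdReluBp(in, eps, dLdIn, 0.1)); // assumed ctor: computes dL/dIn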
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* Cube backpropagation op - dL/dIn from in and dL/dOut
*/
public class CubeBp extends DynamicCustomOp {

public CubeBp(){ }

public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "cube_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -27,7 +27,11 @@ import java.util.List;
/**
* Cube derivative, e.g. 3x^2
*
* @deprecated Use {@link CubeBp}
*
*/
@Deprecated
public class CubeDerivative extends BaseTransformStrictOp {
public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
@@ -35,8 +35,11 @@ import java.util.List;
* Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
* <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
*
* @deprecated Use {@link EluBp}
*
* @author Alex Black
*/
@Deprecated
public class ELUDerivative extends BaseTransformStrictOp {
public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* ELU backpropagation op - dL/dIn from in and dL/dOut
*/
public class EluBp extends DynamicCustomOp {

public EluBp(){ }

public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "elu_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardSigmoidBp extends DynamicCustomOp {

public HardSigmoidBp(){ }

public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "hardsigmoid_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -29,8 +29,11 @@ import java.util.List;
/**
* HardSigmoid derivative
*
* @deprecated Use {@link HardSigmoidBp}
*
* @author raver119@gmail.com
*/
@Deprecated
public class HardSigmoidDerivative extends BaseTransformStrictOp {
public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* Hard Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardTanhBp extends DynamicCustomOp {

public HardTanhBp(){ }

public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "hardtanh_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -31,8 +31,11 @@ import java.util.List;
/**
* Hard tanh elementwise derivative function
*
* @deprecated Use {@link HardTanhBp}
*
* @author Adam Gibson
*/
@Deprecated
public class HardTanhDerivative extends BaseTransformStrictOp {
public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* LReLU backpropagation op - dL/dIn from in and dL/dOut
*/
public class LeakyReLUBp extends DynamicCustomOp {
public static final double DEFAULT_ALPHA = 0.01;
private double alpha = DEFAULT_ALPHA;

public LeakyReLUBp(){ }

public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
super(sd, new SDVariable[]{input, gradient});
this.alpha = alpha;
addTArgument(alpha);
}

public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.alpha = alpha;
addTArgument(alpha);
}

@Override
public String opName(){
return "lrelu_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* Rational Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RationalTanhBp extends DynamicCustomOp {

public RationalTanhBp(){ }

public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "rationaltanh_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -31,9 +31,12 @@ import java.util.List;
* Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351
* Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input
*
* @deprecated Use {@link RationalTanhBp}
*
* @author raver119@gmail.com
* @author AlexDBlack
*/
@Deprecated
public class RationalTanhDerivative extends BaseTransformStrictOp {
public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace);
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
* Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RectifiedTanhBp extends DynamicCustomOp {

public RectifiedTanhBp(){ }

public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}

public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}

@Override
public String opName(){
return "rectifiedtanh_bp";
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

return Collections.singletonList(dataTypes.get(0));
}

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
@@ -30,9 +30,12 @@ import java.util.List;
/**
* Rectified Tanh Derivative
*
* @deprecated Use {@link RectifiedTanhBp}
*
* @author raver119@gmail.com
* @author AlexDBlack
*/
@Deprecated
public class RectifiedTanhDerivative extends BaseTransformStrictOp {
public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace);
@@ -16,15 +16,18 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

import java.util.Collections;
import java.util.List;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;

/**
* Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen.
@@ -33,7 +36,9 @@ import java.util.List;
*/
public class Relu6Derivative extends DynamicCustomOp {

private double cutoff = 0.0;
private static final double DEFAULT_CUTOFF = 0.0;

private double cutoff = DEFAULT_CUTOFF;

public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) {
super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2});
@@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp {
this.extraArgs = new Object[]{cutoff};
}

public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
this(input, gradient, output, DEFAULT_CUTOFF);
}

public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.cutoff = cutoff;
this.extraArgs = new Object[]{cutoff};
}

@Override
public int opNum() {
return 0;
@@ -31,8 +31,11 @@ import java.util.List;
*
* https://arxiv.org/pdf/1706.02515.pdf
*
* @deprecated Use {@link SeluBp}
*
* @author raver119@gmail.com
*/
@Deprecated
public class SELUDerivative extends BaseTransformStrictOp {

private static final double SELU_ALPHA = 1.6732632423543772848170429916717;
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SELU backpropagation op - dL/dIn from in and dL/dOut
 */
public class SeluBp extends DynamicCustomOp {

    public SeluBp(){ }

    public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "selu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions
                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
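For reference, assuming the standard SELU definition from the linked paper (alpha ≈ 1.6732632423543772848, as in SELU_ALPHA above, and lambda ≈ 1.0507009873554805), selu_bp computes dL/dIn = dL/dOut · ∂selu/∂in, where:

\[
\mathrm{selu}(x) = \lambda \begin{cases} x & x > 0 \\ \alpha\,(e^{x} - 1) & x \le 0 \end{cases}
\qquad
\frac{\partial\,\mathrm{selu}(x)}{\partial x} = \lambda \begin{cases} 1 & x > 0 \\ \alpha\, e^{x} & x \le 0 \end{cases}
\]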
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SoftPlus backpropagation op - dL/dIn from in and dL/dOut
 */
public class SoftPlusBp extends DynamicCustomOp {

    public SoftPlusBp(){ }

    public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "softplus_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions
                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
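For reference, softplus(x) = ln(1 + e^x) and its derivative is the logistic sigmoid, so softplus_bp is expected to compute:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \sigma(x),
\qquad
\sigma(x) = \frac{1}{1 + e^{-x}}
\]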
@@ -0,0 +1,63 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SoftSign backpropagation op - dL/dIn from in and dL/dOut
 */
public class SoftSignBp extends DynamicCustomOp {

    public SoftSignBp(){ }

    public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "softsign_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions
                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
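For reference, softsign(x) = x / (1 + |x|), so softsign_bp is expected to compute:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \frac{1}{(1 + |x|)^{2}}
\]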
@@ -29,7 +29,10 @@ import java.util.List;

/**
 * SoftSign derivative.
 *
 * @deprecated Use {@link SoftSignBp}
 */
@Deprecated
public class SoftSignDerivative extends BaseTransformStrictOp {
    public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        super(sameDiff, i_v, inPlace);
@@ -16,10 +16,12 @@

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

import java.util.Collections;

@@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp {
        addIArgument(dimension);
    }

    public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
        super(new INDArray[]{input, grad}, wrapOrNull(output));
        if(dimension != null)
            addIArgument(dimension);
    }

    @Override
    public String opName() {
        return "softmax_bp";
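For reference, with s = softmax(x) taken along the given dimension, the backprop corresponds to the usual softmax Jacobian-vector product (a worked form for orientation, not taken from the kernel itself):

\[
\frac{\partial L}{\partial x_i} = s_i \left( \frac{\partial L}{\partial s_i} - \sum_j \frac{\partial L}{\partial s_j}\, s_j \right)
\]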
@@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp {
        super(sameDiff, new SDVariable[]{i_v1, i_v2});
    }

    public TanhDerivative(INDArray x, INDArray z) {
        super(null, x, z, null, null);
    public TanhDerivative(INDArray x, INDArray y, INDArray z) {
        super(null, new INDArray[]{x, y}, new INDArray[]{z});
    }

    public TanhDerivative() {
    }

    public TanhDerivative(INDArray x) {
        this(x, null);
    public TanhDerivative(INDArray x, INDArray y) {
        this(x, y, null);
    }

    @Override
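The constructor change reflects that the backprop op now takes both the forward input x and the incoming gradient y = dL/dOut, producing dL/dIn directly:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \left( 1 - \tanh^{2}(x) \right)
\]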
@@ -0,0 +1,74 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;

/**
 * Threshold ReLU Backprop op - dL/dIn from in and dL/dOut
 *
 * For {@link RectifiedLinear} as well as {@link ThresholdRelu}.
 */
public class ThresholdReluBp extends DynamicCustomOp {

    @Getter
    private double cutoff = 0;

    public ThresholdReluBp(){ }

    public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){
        super(sd, new SDVariable[]{input, gradient});
        this.cutoff = cutoff;
        addTArgument(cutoff);
    }

    public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
        this.cutoff = cutoff;
        addTArgument(cutoff);
    }

    @Override
    public String opName(){
        return "thresholdedrelu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions
                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
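For reference, thresholded ReLU is f(x) = x for x > cutoff and 0 otherwise (plain ReLU being the cutoff = 0 case), so the backprop op gates the incoming gradient on the cutoff θ:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \mathbf{1}\,[\, x > \theta \,]
\]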
@@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.same;

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;

@@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        SDVariable g = f().mul(f().cubeDerivative(arg()), f1.get(0));
        return Arrays.asList(g);
        return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));
    }
}
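The replaced expression and the new cubeBp op should agree with the elementwise chain rule for x^3:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot 3x^{2}
\]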
@@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;

@@ -76,8 +77,7 @@ public class ELU extends BaseTransformStrictOp {
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        //ELU: e^x-1 if x<0, x otherwise
        //dL/dIn = dL/dOut * dOut/dIn
        SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0));
        return Arrays.asList(ret);
        return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
    }

}
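Following the comment in the diff (ELU with alpha = 1: e^x - 1 for x < 0, x otherwise), eluBp corresponds to:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \begin{cases} 1 & x \ge 0 \\ e^{x} & x < 0 \end{cases}
\]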
@@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        SDVariable in = arg();
        SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0];
        return Collections.singletonList(dOutdIn.mul(f1.get(0)));
        return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0)));
    }

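Assuming the usual hard-sigmoid definition f(x) = min(1, max(0, 0.2x + 0.5)) (an assumption; the exact form is defined by the forward op, not this diff), hardSigmoidBp reduces to a piecewise-constant gate:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \begin{cases} 0.2 & -2.5 < x < 2.5 \\ 0 & \text{otherwise} \end{cases}
\]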
@@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0));
        return Arrays.asList(ret);
        return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0)));
    }
}
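For hard tanh, f(x) = max(-1, min(1, x)), so hardTanhBp passes the gradient only inside the linear region:

\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \mathbf{1}\,[\, -1 < x < 1 \,]
\]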
@@ -16,13 +16,10 @@

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;

import java.util.Collections;

@@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0)));
        return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0)));
    }
}
@@ -17,13 +17,10 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;

import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;

@@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0)));
        return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0)));
    }
}
@@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0));
        return Arrays.asList(ret);
        return Collections.singletonList(f().seluBp(arg(), i_v.get(0)));
    }

}
@@ -28,8 +28,11 @@ import java.util.List;
/**
 * Sigmoid derivative
 *
 * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative}
 *
 * @author Adam Gibson
 */
@Deprecated
public class SigmoidDerivative extends BaseTransformStrictOp {
    public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        super(sameDiff, i_v1, i_v2);
@@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp {

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0));
        return Arrays.asList(ret);
        return Collections.singletonList(f().softsignBp(arg(), i_v.get(0)));
    }

}
@@ -27,7 +27,10 @@ import java.util.List;

/**
 * Tanh derivative
 *
 * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}.
 */
@Deprecated
public class TanhDerivative extends BaseTransformStrictOp {
    public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        super(sameDiff, i_v, inPlace);
@@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double b = 3.0 / Math.sqrt(fanIn);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
    }

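The remaining weight-init hunks below apply the same mechanical change: the deprecated shape-first Nd4j.rand(shape, distribution) call becomes the distribution-first overload with a vararg shape, as the diff shows. A minimal sketch of the pattern follows; the class name LecunUniformSketch and the fan-in/shape values are illustrative only.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

public class LecunUniformSketch {
    public static void main(String[] args) {
        double fanIn = 128;                                // illustrative fan-in
        double b = 3.0 / Math.sqrt(fanIn);                 // Lecun uniform bound, as in the scheme above
        Distribution u = Nd4j.getDistributions().createUniform(-b, b);
        // old (deprecated): Nd4j.rand(new long[]{128, 64}, u)
        // new:              distribution first, shape as varargs
        INDArray w = Nd4j.rand(u, 128, 64);
        System.out.println(w.shapeInfoToString());
    }
}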
@@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double u = Math.sqrt(6.0 / fanIn);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn))
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn))
    }

@@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape);
    }

@@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double a = 1.0 / Math.sqrt(fanIn);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
    }

@@ -43,7 +43,7 @@ public class VarScalingNormalUniformFanInInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double scalingFanIn = 3.0 / Math.sqrt(fanIn);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn), shape);
    }

@@ -42,7 +42,7 @@ public class VarScalingNormalUniformFanOutInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double scalingFanOut = 3.0 / Math.sqrt(fanOut);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut), shape);
    }

@@ -46,7 +46,7 @@ public class VarScalingUniformFanAvgInitScheme extends BaseWeightInitScheme {
    @Override
    public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
        double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg), shape);
    }

@@ -46,7 +46,7 @@ public class XavierUniformInitScheme extends BaseWeightInitScheme {
        //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
        //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
        double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s));
        return Nd4j.rand(Nd4j.getDistributions().createUniform(-s, s), shape);
    }
