New Nd4j backprop ops for activations (#211)

* new (for Java, at least) backprop ops

Signed-off-by: Ryan Nett <rnett@skymind.io>

* update activation functions

Signed-off-by: Ryan Nett <rnett@skymind.io>

* add differential functions for SameDiff

Signed-off-by: Ryan Nett <rnett@skymind.io>

* deprecate old ops

Signed-off-by: Ryan Nett <rnett@skymind.io>

* update correct old ops

Signed-off-by: Ryan Nett <rnett@skymind.io>

* update ops backprop to use new ops

Signed-off-by: Ryan Nett <rnett@skymind.io>

* misc updates for deprecated functions (mostly Nd4j.rand w/ vararg shape)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove old imports

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-09-01 23:15:23 -07:00 committed by Alex Black
parent 6d04d30c94
commit b3a134b608
62 changed files with 1053 additions and 103 deletions
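Note on the recurring pattern in the diffs below: the deprecated *Derivative ops compute only dOut/dIn, so every caller had to multiply by the incoming gradient itself; the new *Bp ops take (input, dL/dOut) and return dL/dIn directly. A minimal before/after sketch at the SameDiff level, inside an op's doDiff, using the Cube op exactly as the changes below do (f() is the DifferentialFunctionFactory):

// Before: derivative op, then an explicit multiply by the incoming gradient f1.get(0)
SDVariable g = f().mul(f().cubeDerivative(arg()), f1.get(0));
return Arrays.asList(g);

// After: a single backprop op that consumes the gradient itself
return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));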

View File

@ -204,20 +204,30 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
@ -1126,10 +1136,26 @@ public class DifferentialFunctionFactory {
return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
}
public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) {
return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable();
}
public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) {
return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #tanhRationalBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRationalDerivative(SDVariable in) {
return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
}
/**
* @deprecated Use {@link #tanhRectifiedBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRectifiedDerivative(SDVariable in) {
return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
}
@ -1280,6 +1306,14 @@ public class DifferentialFunctionFactory {
return new Cube(sameDiff(), iX, false).outputVariable();
}
public SDVariable cubeBp(SDVariable in, SDVariable epsilon) {
return new CubeBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #cubeBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable cubeDerivative(SDVariable iX) {
return new CubeDerivative(sameDiff(), iX, false).outputVariable();
}
@ -1329,6 +1363,14 @@ public class DifferentialFunctionFactory {
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
}
public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable();
}
public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}
public SDVariable relu6(SDVariable iX, double cutoff) {
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
}
@ -1350,6 +1392,14 @@ public class DifferentialFunctionFactory {
return new HardTanh(sameDiff(), iX, false).outputVariable();
}
public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) {
return new HardTanhBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable hardTanhDerivative(SDVariable iX) {
return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
}
@ -1358,6 +1408,9 @@ public class DifferentialFunctionFactory {
return new HardSigmoid(sameDiff(), in, false).outputVariable();
}
public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){
return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable();
}
public SDVariable sigmoid(SDVariable iX) {
return new Sigmoid(sameDiff(), iX, false).outputVariable();
@ -1486,10 +1539,16 @@ public class DifferentialFunctionFactory {
}
public SDVariable softsignBp(SDVariable in, SDVariable epsilon) {
return new SoftSignBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #softsignBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable softsignDerivative(SDVariable iX) {
return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
}
@ -1504,10 +1563,16 @@ public class DifferentialFunctionFactory {
}
public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
return new EluBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable eluDerivative(SDVariable iX) {
return new ELUDerivative(sameDiff(), iX, false).outputVariable();
}
@ -1516,6 +1581,14 @@ public class DifferentialFunctionFactory {
}
public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) {
return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}
/**
* @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)}
*/
@Deprecated
public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
}
@ -1832,7 +1905,15 @@ public class DifferentialFunctionFactory {
return new SELU(sameDiff(), arg, false).outputVariable();
}
public SDVariable seluBp(SDVariable in, SDVariable epsilon) {
validateDifferentialFunctionsameDiff(in);
return new SeluBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #seluBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable seluDerivative(SDVariable arg) {
validateDifferentialFunctionsameDiff(arg);
return new SELUDerivative(sameDiff(), arg, false).outputVariable();

View File

@ -901,6 +901,17 @@ public class OpValidation {
TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
PowDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
BiasAddGrad.class,
ConcatBp.class,

View File

@ -229,6 +229,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
@ -433,6 +434,17 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,

View File

@ -21,6 +21,8 @@ import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j;
@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override
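The same swap is applied to each nd4j activation class that follows: the *Bp op fuses the derivative with the multiply by epsilon, and the input array is reused as the output buffer, so the Pair simply hands back in (presumably the pre-activations are no longer needed once backprop has run). If a caller did still need them, a non-in-place variant would look like the following hedged sketch, with dLdIn as an illustrative name:

// Variant that preserves `in`: write dL/dIn into a fresh buffer instead of in place
INDArray dLdIn = in.ulike();
Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, dLdIn));
return new Pair<>(dLdIn, null);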

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -85,9 +86,8 @@ public class ActivationELU extends BaseActivationFunction {
}
else {
INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
return new Pair<>(in, null);
}
}

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha));
return new Pair<>(in, null);
}
@Override

View File

@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) {
if (training) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
}
INDArray inTimesAlpha = in.mul(alpha);
BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
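This is the "Nd4j.rand w/ vararg shape" cleanup mentioned in the commit message: per that note, the overload taking the shape array first is deprecated in favor of the one taking the bounds and RNG first, with the shape as trailing varargs. Side by side, both lines taken from this hunk:

// Deprecated overload: shape first
this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
// Replacement: bounds and RNG first, shape as varargs
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());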

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -45,9 +47,10 @@ public class ActivationRectifiedTanh extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -21,7 +21,9 @@ import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
INDArray x = out.mul(epsilon).sum(1);
INDArray dLdz = out.mul(epsilon.subColumnVector(x));
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1));
return new Pair<>(in, null);
}
@Override
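Unlike the element-wise activations, softmax couples all outputs along a row, which is why the removed code first recomputed the softmax output and then formed out * (epsilon - sum(out * epsilon)) per row. SoftmaxBp computes the same quantity in one call along the dimension given as its last argument (-1 here, the last axis, which for 2D activations matches the old dimension-1 code). A rough equivalence sketch, reusing the removed lines purely for illustration:

// Reference gradient, exactly as the removed code computed it (2D input):
INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
INDArray rowSums = out.mul(epsilon).sum(1);                      // per-row sum of out * epsilon
INDArray reference = out.mul(epsilon.subColumnVector(rowSums));  // out * (epsilon - rowSum)

// New op: the same dL/dIn in one call, along the last dimension
INDArray viaOp = in.ulike();
Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, viaOp, -1));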

View File

@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
/**
@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction {
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
return new Pair<>(in, null);
}
@Override

View File

@ -16,17 +16,13 @@
package org.nd4j.linalg.api.ops.impl.scalar;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha));
}
@Override

View File

@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp {
return "Relu";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
if(scalarValue.getDouble(0) == 0.0){
return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
} else {
SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
SDVariable ret = step.mul(i_v.get(0));
return Collections.singletonList(ret);
}
return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0)));
}
}

View File

@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.shade.guava.base.Preconditions;
import java.util.Collections;
import java.util.List;
@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
/**
* Threshold ReLU op. The general case of {@link RectifiedLinear}.
*/
public class ThresholdRelu extends DynamicCustomOp {
@Getter
private double cutoff = 0.0;
public ThresholdRelu(){ }
public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){
super(sd, new SDVariable[]{input}, inPlace);
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){
super(sd, new SDVariable[]{input});
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){
super(new INDArray[]{input}, wrapOrNull(output));
this.cutoff = cutoff;
addTArgument(cutoff);
}
@Override
public String opName(){
return "thresholdedrelu";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff));
}
}
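A hypothetical forward-pass sketch for the new op (not part of this commit; the array values are made up), assuming the usual thresholded-ReLU semantics of x > cutoff ? x : 0, with cutoff 0.0 reducing to plain ReLU:

INDArray x = Nd4j.rand(3, 4).subi(0.5);   // values roughly in [-0.5, 0.5)
INDArray out = x.ulike();
Nd4j.getExecutioner().exec(new ThresholdRelu(x, out, 0.0));
// expected: out = x where x > 0.0, and 0 elsewhere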

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Cube backpropagation op - dL/dIn from in and dL/dOut
*/
public class CubeBp extends DynamicCustomOp {
public CubeBp(){ }
public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "cube_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -27,7 +27,11 @@ import java.util.List;
/**
* Cube derivative, e.g. 3x^2
*
* @deprecated Use {@link CubeBp}
*
*/
@Deprecated
public class CubeDerivative extends BaseTransformStrictOp {
public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -35,8 +35,11 @@ import java.util.List;
* Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
* <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
*
* @deprecated Use {@link EluBp}
*
* @author Alex Black
*/
@Deprecated
public class ELUDerivative extends BaseTransformStrictOp {
public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* ELU backpropagation op - dL/dIn from in and dL/dOut
*/
public class EluBp extends DynamicCustomOp {
public EluBp(){ }
public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "elu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardSigmoidBp extends DynamicCustomOp {
public HardSigmoidBp(){ }
public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "hardsigmoid_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -29,8 +29,11 @@ import java.util.List;
/**
* HardSigmoid derivative
*
* @deprecated Use {@link HardSigmoidBp}
*
* @author raver119@gmail.com
*/
@Deprecated
public class HardSigmoidDerivative extends BaseTransformStrictOp {
public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Hard Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardTanhBp extends DynamicCustomOp {
public HardTanhBp(){ }
public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "hardtanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -31,8 +31,11 @@ import java.util.List;
/**
* Hard tanh elementwise derivative function
*
* @deprecated Use {@link HardTanhBp}
*
* @author Adam Gibson
*/
@Deprecated
public class HardTanhDerivative extends BaseTransformStrictOp {
public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -0,0 +1,68 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* LReLU backpropagation op - dL/dIn from in and dL/dOut
*/
public class LeakyReLUBp extends DynamicCustomOp {
public static final double DEFAULT_ALPHA = 0.01;
private double alpha = DEFAULT_ALPHA;
public LeakyReLUBp(){ }
public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
super(sd, new SDVariable[]{input, gradient});
this.alpha = alpha;
addTArgument(alpha);
}
public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.alpha = alpha;
addTArgument(alpha);
}
@Override
public String opName(){
return "lrelu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Rational Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RationalTanhBp extends DynamicCustomOp {
public RationalTanhBp(){ }
public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "rationaltanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -31,9 +31,12 @@ import java.util.List;
* Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351
* Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input
*
* @deprecated Use {@link RationalTanhBp}
*
* @author raver119@gmail.com
* @author AlexDBlack
*/
@Deprecated
public class RationalTanhDerivative extends BaseTransformStrictOp {
public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace);

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RectifiedTanhBp extends DynamicCustomOp {
public RectifiedTanhBp(){ }
public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "rectifiedtanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -30,9 +30,12 @@ import java.util.List;
/**
* Rectified Tanh Derivative
*
* @deprecated Use {@link RectifiedTanhBp}
*
* @author raver119@gmail.com
* @author AlexDBlack
*/
@Deprecated
public class RectifiedTanhDerivative extends BaseTransformStrictOp {
public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace);

View File

@ -16,15 +16,18 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
/**
* Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen.
@ -33,7 +36,9 @@ import java.util.List;
*/
public class Relu6Derivative extends DynamicCustomOp {
private double cutoff = 0.0;
private static final double DEFAULT_CUTOFF = 0.0;
private double cutoff = DEFAULT_CUTOFF;
public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) {
super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2});
@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp {
this.extraArgs = new Object[]{cutoff};
}
public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
this(input, gradient, output, DEFAULT_CUTOFF);
}
public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.cutoff = cutoff;
this.extraArgs = new Object[]{cutoff};
}
@Override
public int opNum() {
return 0;

View File

@ -31,8 +31,11 @@ import java.util.List;
*
* https://arxiv.org/pdf/1706.02515.pdf
*
* @deprecated Use {@link SeluBp}
*
* @author raver119@gmail.com
*/
@Deprecated
public class SELUDerivative extends BaseTransformStrictOp {
private static final double SELU_ALPHA = 1.6732632423543772848170429916717;

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SELU backpropagation op - dL/dIn from in and dL/dOut
*/
public class SeluBp extends DynamicCustomOp {
public SeluBp(){ }
public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "selu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SoftPlus backpropagation op - dL/dIn from in and dL/dOut
*/
public class SoftPlusBp extends DynamicCustomOp {
public SoftPlusBp(){ }
public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "softplus_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SoftSign backpropagation op - dL/dIn from in and dL/dOut
*/
public class SoftSignBp extends DynamicCustomOp {
public SoftSignBp(){ }
public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "softsign_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -29,7 +29,10 @@ import java.util.List;
/**
* SoftSign derivative.
*
* @deprecated Use {@link SoftSignBp}
*/
@Deprecated
public class SoftSignDerivative extends BaseTransformStrictOp {
public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp {
addIArgument(dimension);
}
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
super(new INDArray[]{input, grad}, wrapOrNull(output));
if(dimension != null)
addIArgument(dimension);
}
@Override
public String opName() {
return "softmax_bp";

View File

@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp {
super(sameDiff, new SDVariable[]{i_v1, i_v2});
}
public TanhDerivative(INDArray x, INDArray z) {
super(null, x, z, null, null);
public TanhDerivative(INDArray x, INDArray y, INDArray z) {
super(null, new INDArray[]{x, y}, new INDArray[]{z});
}
public TanhDerivative() {
}
public TanhDerivative(INDArray x) {
this(x, null);
public TanhDerivative(INDArray x, INDArray y) {
this(x, y, null);
}
@Override

View File

@ -0,0 +1,74 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;
/**
* Threshold ReLU Backprop op - dL/dIn from in and dL/dOut
*
* For {@link RectifiedLinear} as well as {@link ThresholdRelu}.
*/
public class ThresholdReluBp extends DynamicCustomOp {
@Getter
private double cutoff = 0;
public ThresholdReluBp(){ }
public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){
super(sd, new SDVariable[]{input, gradient});
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.cutoff = cutoff;
addTArgument(cutoff);
}
@Override
public String opName(){
return "thresholdedrelu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
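Since RectifiedLinear.doDiff now routes through thresholdReluBp for any cutoff, this single op covers both plain ReLU backprop (cutoff 0.0) and the thresholded case. A hypothetical INDArray-level sketch (names and values are illustrative, not taken from this commit):

INDArray in = Nd4j.rand(3, 4).subi(0.5);   // pre-activations, roughly in [-0.5, 0.5)
INDArray eps = Nd4j.rand(3, 4);            // incoming gradient dL/dOut
INDArray dLdIn = in.ulike();
Nd4j.getExecutioner().exec(new ThresholdReluBp(in, eps, dLdIn, 0.0));
// expected: dLdIn = eps where in > 0.0, and 0 elsewhere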

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.same;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
SDVariable g = f().mul(f().cubeDerivative(arg()),f1.get(0));
return Arrays.asList(g);
return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));
}
}
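The fused cubeBp call should match the old cubeDerivative-times-gradient path. As a sanity check, this is the element-wise math it replaces (a sketch based on d(x^3)/dx = 3x^2, not the op's actual kernel):

public class CubeBpReference {
    // dL/dIn = dL/dOut * 3x^2
    static double cubeGrad(double x, double dLdOut) {
        return dLdOut * 3.0 * x * x;
    }
    public static void main(String[] args) {
        System.out.println(cubeGrad(2.0, 1.0));   // 12.0
    }
}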

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -76,8 +77,7 @@ public class ELU extends BaseTransformStrictOp {
public List<SDVariable> doDiff(List<SDVariable> i_v) {
//ELU: e^x-1 if x<0, x otherwise
//dL/dIn = dL/dOut * dOut/dIn
SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
}
}
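Following the comment above (alpha assumed to be 1.0, the default), the element-wise gradient the new eluBp path is expected to produce, as a standalone reference:

public class EluBpReference {
    // dOut/dIn = e^x for x < 0 (slope of e^x - 1), 1 otherwise; dL/dIn = dL/dOut * dOut/dIn
    static double eluGrad(double x, double dLdOut) {
        return dLdOut * (x < 0 ? Math.exp(x) : 1.0);
    }
    public static void main(String[] args) {
        System.out.println(eluGrad(-0.5, 1.0));   // ~0.6065
        System.out.println(eluGrad( 2.0, 1.0));   // 1.0
    }
}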

View File

@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
SDVariable in = arg();
SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0];
return Collections.singletonList(dOutdIn.mul(f1.get(0)));
return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0)));
}
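A hedged element-wise reference for the hardSigmoidBp path, assuming the usual hard sigmoid definition f(x) = min(1, max(0, 0.2x + 0.5)):

public class HardSigmoidBpReference {
    // f'(x) = 0.2 on the linear segment (-2.5 < x < 2.5), 0 where the output is clipped
    static double hardSigmoidGrad(double x, double dLdOut) {
        return (x > -2.5 && x < 2.5) ? 0.2 * dLdOut : 0.0;
    }
    public static void main(String[] args) {
        System.out.println(hardSigmoidGrad(1.0, 1.0));   // 0.2
        System.out.println(hardSigmoidGrad(3.0, 1.0));   // 0.0
    }
}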

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0)));
}
}
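Same pattern for hard tanh, which clamps the input to [-1, 1]; the element-wise gradient the hardTanhBp call is expected to compute (a sketch, not the op's actual kernel):

public class HardTanhBpReference {
    // f(x) = max(-1, min(1, x)); f'(x) = 1 inside (-1, 1), 0 where clamped
    static double hardTanhGrad(double x, double dLdOut) {
        return (x > -1.0 && x < 1.0) ? dLdOut : 0.0;
    }
    public static void main(String[] args) {
        System.out.println(hardTanhGrad( 0.3, 2.0));   // 2.0
        System.out.println(hardTanhGrad(-1.7, 2.0));   // 0.0
    }
}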

View File

@ -16,13 +16,10 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import java.util.Collections;
@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0)));
return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0)));
}
}

View File

@ -17,13 +17,10 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0)));
return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0)));
}
}
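Assuming rectified tanh is f(x) = max(0, tanh(x)), the element-wise gradient behind tanhRectifiedBp would reduce to the following sketch (illustrative names):

public class RectifiedTanhBpReference {
    // f'(x) = 1 - tanh(x)^2 for x > 0 (where tanh(x) > 0), 0 otherwise
    static double rectifiedTanhGrad(double x, double dLdOut) {
        if (x <= 0) return 0.0;
        double t = Math.tanh(x);
        return dLdOut * (1.0 - t * t);
    }
    public static void main(String[] args) {
        System.out.println(rectifiedTanhGrad( 0.5, 1.0));   // ~0.7864
        System.out.println(rectifiedTanhGrad(-0.5, 1.0));   // 0.0
    }
}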

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().seluBp(arg(), i_v.get(0)));
}
}
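For SELU, the expected element-wise gradient uses the standard constants from Klambauer et al. 2017; treat the exact values as an assumption about the underlying implementation:

public class SeluBpReference {
    static final double LAMBDA = 1.0507009873554805;
    static final double ALPHA  = 1.6732632423543772;
    // f(x) = LAMBDA * x for x > 0, LAMBDA * ALPHA * (e^x - 1) otherwise
    // f'(x) = LAMBDA for x > 0, LAMBDA * ALPHA * e^x otherwise
    static double seluGrad(double x, double dLdOut) {
        return dLdOut * (x > 0 ? LAMBDA : LAMBDA * ALPHA * Math.exp(x));
    }
    public static void main(String[] args) {
        System.out.println(seluGrad( 1.0, 1.0));   // ~1.0507
        System.out.println(seluGrad(-1.0, 1.0));   // ~0.6468
    }
}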

View File

@ -28,8 +28,11 @@ import java.util.List;
/**
* Sigmoid derivative
*
* @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative}
*
* @author Adam Gibson
*/
@Deprecated
public class SigmoidDerivative extends BaseTransformStrictOp {
public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0));
return Arrays.asList(ret);
return Collections.singletonList(f().softsignBp(arg(), i_v.get(0)));
}
}
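Softsign is f(x) = x / (1 + |x|), so the fused softsignBp should reduce to this element-wise rule (a reference sketch only):

public class SoftSignBpReference {
    // f'(x) = 1 / (1 + |x|)^2
    static double softSignGrad(double x, double dLdOut) {
        double d = 1.0 + Math.abs(x);
        return dLdOut / (d * d);
    }
    public static void main(String[] args) {
        System.out.println(softSignGrad(2.0, 1.0));   // ~0.1111
    }
}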

View File

@ -27,7 +27,10 @@ import java.util.List;
/**
* Tanh derivative
*
* @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}.
*/
@Deprecated
public class TanhDerivative extends BaseTransformStrictOp {
public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);

View File

@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme {
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double b = 3.0 / Math.sqrt(fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
}
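The remaining weight-init changes below are the same one-line swap: the distribution now comes first, since the Nd4j.rand overload taking the shape as its first argument is deprecated by this PR. A standalone sketch of the new call order (the fanIn of 64 and the class name are just illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RandArgumentOrderSketch {
    public static void main(String[] args) {
        long[] shape = {64, 32};
        double b = 3.0 / Math.sqrt(64);   // same bound as LecunUniformInitScheme with fanIn = 64
        // Old (deprecated): Nd4j.rand(shape, distribution); new: distribution first, shape last.
        INDArray w = Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
        System.out.println(java.util.Arrays.toString(w.shape()));
    }
}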

View File

@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme {
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double u = Math.sqrt(6.0 / fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn))
}

View File

@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme {
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape);
}

View File

@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme {
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double a = 1.0 / Math.sqrt(fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
}

View File

@ -43,7 +43,7 @@ public class VarScalingNormalUniformFanInInitScheme extends BaseWeightInitScheme
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double scalingFanIn = 3.0 / Math.sqrt(fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn), shape);
}

View File

@ -42,7 +42,7 @@ public class VarScalingNormalUniformFanOutInitScheme extends BaseWeightInitSchem
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double scalingFanOut = 3.0 / Math.sqrt(fanOut);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut), shape);
}

View File

@ -46,7 +46,7 @@ public class VarScalingUniformFanAvgInitScheme extends BaseWeightInitScheme {
@Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg), shape);
}

View File

@ -46,7 +46,7 @@ public class XavierUniformInitScheme extends BaseWeightInitScheme {
//As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
//Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s));
return Nd4j.rand(Nd4j.getDistributions().createUniform(-s, s), shape);
}