parent 59a6e4e3ae
commit b46f9827b8
@@ -530,7 +530,7 @@ public abstract class DifferentialFunction {
     public SDVariable arg(int num){
         SDVariable[] args = args();
         Preconditions.checkNotNull(args, "Arguments are null for function %s", this.getOwnName());
-        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s)", args.length);
+        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", args.length, num);
         return args[num];
     }
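Note: the only change to arg(int) is the precondition message, which now reports the offending index alongside the valid range. A minimal sketch of the difference as seen by a caller, assuming a hypothetical op with two arguments:

    // op has 2 args; asking for index 5 fails the range check
    // before: "Invalid index: must be 0 to numArgs (0 <= idx < 2)"
    // after:  "Invalid index: must be 0 to numArgs (0 <= idx < 2), got 5"
    SDVariable a = op.arg(5);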
@@ -46,6 +46,7 @@ public class LayerNorm extends DynamicCustomOp {

     public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         super(null, sameDiff, wrapFilterNull(input, gain, bias), false);
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
@@ -56,6 +57,7 @@ public class LayerNorm extends DynamicCustomOp {

     public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, boolean channelsFirst, int... dimensions) {
         super("layer_norm", wrapFilterNull(input, gain, bias), wrapOrNull(result));
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
@@ -115,4 +117,8 @@ public class LayerNorm extends DynamicCustomOp {
         return Collections.singletonList(first);
     }

+    @Override
+    public int numOutputArguments() {
+        return noBias ? 2 : 3;
+    }
 }
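Both LayerNorm constructors now record whether a bias input was supplied, and the new numOutputArguments() override reports 2 or 3 accordingly. A minimal sketch of the two construction paths, assuming sd, input, gain and bias already exist (the dimension argument is illustrative):

    LayerNorm withBias = new LayerNorm(sd, input, gain, bias, true, 1);  // noBias == false -> 3
    LayerNorm biasless = new LayerNorm(sd, input, gain, null, true, 1);  // noBias == true  -> 2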
@@ -45,12 +45,14 @@ public class LayerNormBp extends DynamicCustomOp {

     public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
         super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }

     public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
         super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
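LayerNormBp gets the same noBias bookkeeping. Because both the input and output arrays pass through wrapFilterNull, the no-bias backprop case simply passes null for bias and for dLdb, and both are filtered out of the argument lists. A minimal sketch, assuming the INDArrays are already allocated with matching shapes:

    LayerNormBp bp = new LayerNormBp(x, gain, null, grad, dLdx, dLdg, null, true, 1);
    // inputs:  x, gain, grad   (null bias filtered out)
    // outputs: dLdx, dLdg      (null dLdb filtered out)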
@@ -1112,12 +1112,12 @@ public class LayerOpValidation extends BaseOpValidation {

     @Test
     public void testLayerNorm() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));

-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);
         final INDArray expOut = res.norm1();

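The test updates swap the untyped Nd4j.rand(new int[]{...}) overload for Nd4j.rand(DataType, long...), pinning the arrays to double precision (which gradient checks generally require) and making gain and bias true rank-1 vectors rather than 1x4 row matrices. A quick illustration of the shape and dtype difference:

    INDArray rowVec = Nd4j.rand(new int[]{1, 4});     // shape [1, 4], default dtype
    INDArray vec    = Nd4j.rand(DataType.DOUBLE, 4);  // shape [4], double precision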
@@ -1132,7 +1132,7 @@ public class LayerOpValidation extends BaseOpValidation {
         String err = OpValidation.validate(new TestCase(sd)
                 .expectedOutput("out", expOut)
                 .gradientCheck(true));
-        assertNull(err, err);
+        assertNull(err);
     }

     @Test
@@ -1141,9 +1141,9 @@ public class LayerOpValidation extends BaseOpValidation {
         int ch = 4;
         for(boolean nchw : new boolean[]{true, false}) {
             double eps = 0.0;
-            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
-            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
-            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
             INDArray mean = x.mean(true, 1, 2, 3);
             INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);

@@ -1169,12 +1169,12 @@ public class LayerOpValidation extends BaseOpValidation {

     @Test
     public void testLayerNormOP() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));

-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);

         final INDArray output = Nd4j.zerosLike(res);
@@ -1185,11 +1185,11 @@ public class LayerOpValidation extends BaseOpValidation {

     @Test
     public void testLayerNormNoBias() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));

-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain);
         final INDArray expOut = res.norm1();

@@ -1208,11 +1208,11 @@ public class LayerOpValidation extends BaseOpValidation {

     @Test
     public void testLayerNormOPNoBias() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));

-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain);

         final INDArray output = Nd4j.zerosLike(res);
@@ -1223,7 +1223,7 @@ public class LayerOpValidation extends BaseOpValidation {

     @Test
     public void testLayerNormNoDeviation() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         for (int i = 0; i < 4; i++) {
             random.putScalar(1,i, 7);
         }

@@ -1231,8 +1231,8 @@ public class LayerOpValidation extends BaseOpValidation {
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));

-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);
         final INDArray expOut = res.norm1();

@@ -1332,8 +1332,8 @@ public class LayerOpValidation extends BaseOpValidation {
     public void testLayerNormMixedOrders(){
         Nd4j.getRandom().setSeed(12345);
         INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
-        INDArray gain = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f');
-        INDArray bias = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f');
+        INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
+        INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f');

         INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f');
         INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c');
@@ -412,7 +412,7 @@ public class TransformOpValidation extends BaseOpValidation {
                 .expectedOutput("dp0", expOut[0])
                 .expectedOutput("dp1", expOut[1])
                 .gradientCheck(true));
-        assertNull(err, err);
+        assertNull(err);
     }

     @Test