Layer norm 4d case fixes (#174)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-08-27 18:34:53 +10:00 (committed by GitHub)
parent 05d45ec050
commit dce4751fc1
6 changed files with 96 additions and 44 deletions

DifferentialFunctionFactory.java

@@ -790,20 +790,20 @@ public class DifferentialFunctionFactory {
         return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, bias, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables();
     }
 
     public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) {

SDNN.java

@@ -759,8 +759,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return layerNorm(null, input, gain, bias, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return layerNorm(null, input, gain, bias, channelsFirst, dimensions);
     }
 
     /**
@@ -772,13 +772,15 @@ public class SDNN extends SDOps {
      * @param input Input variable
      * @param gain gain
      * @param bias bias
+     * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
      * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
         validateFloatingPoint("layerNorm", "bias", bias);
-        SDVariable result = f().layerNorm(input, gain, bias, dimensions);
+        SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
 
@@ -789,8 +791,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return layerNorm((String)null, input, gain, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return layerNorm((String)null, input, gain, channelsFirst, dimensions);
     }
 
     /**
@@ -803,10 +805,10 @@ public class SDNN extends SDOps {
      * @param gain gain
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
-        SDVariable result = f().layerNorm(input, gain, dimensions);
+        SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
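Illustrative sketch (not part of the diff): how the updated SDNN API above might be called for a plain 2d/MLP activation. Variable names and shapes here are examples only. Per the javadoc, channelsFirst is unused for 2D input; it only matters for 4D (CNN) activations, where it tells the op whether the channel axis is dimension 1 (NCHW) or dimension 3 (NHWC).

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("input", Nd4j.rand(DataType.FLOAT, new long[]{10, 4}));   // 2d/MLP activations
    SDVariable gain = sd.var("gain", Nd4j.rand(DataType.FLOAT, new long[]{4}));
    SDVariable bias = sd.var("bias", Nd4j.rand(DataType.FLOAT, new long[]{4}));
    // dimension = 1 for 2d/MLP data (per the javadoc); channelsFirst is ignored for 2D input
    SDVariable out = sd.nn.layerNorm("layernorm", in, gain, bias, true, 1);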

DynamicCustomOp.java

@@ -35,6 +35,7 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
 
+import java.lang.reflect.Array;
 import java.util.*;
 
 /**
@@ -611,6 +612,21 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         return in == null ? null : new INDArray[]{in};
     }
 
+    protected static <T> T[] wrapFilterNull(T... in){
+        int count = 0;
+        for( int i=0; i<in.length; i++ ) {
+            if (in[i] != null) count++;
+        }
+        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
+        int j=0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null){
+                out[j++] = in[i];
+            }
+        }
+        return out;
+    }
+
     public static class DynamicCustomOpsBuilder {
         protected String opName;
         protected int numInputs;
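The new wrapFilterNull helper is what lets the LayerNorm/LayerNormBp constructors below accept an optional (null) bias and still pass a compact argument array to the op. A minimal standalone sketch of the same null-filtering idea is shown here; the class name is hypothetical and the code is not part of the commit.

    import java.lang.reflect.Array;
    import java.util.Arrays;

    public class FilterNullDemo {   // hypothetical demo class, not in the codebase
        @SafeVarargs
        static <T> T[] filterNull(T... in) {
            int count = 0;
            for (T t : in) {                  // first pass: count the non-null entries
                if (t != null) count++;
            }
            @SuppressWarnings("unchecked")
            T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
            int j = 0;
            for (T t : in) {                  // second pass: copy the non-null entries
                if (t != null) out[j++] = t;
            }
            return out;
        }

        public static void main(String[] args) {
            // With a null bias only {input, gain} survive - mirroring the no-bias LayerNorm case
            System.out.println(Arrays.toString(filterNull("input", "gain", null)));   // prints [input, gain]
        }
    }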

LayerNorm.java

@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -41,17 +42,16 @@ import java.util.List;
 public class LayerNorm extends DynamicCustomOp {
 
     private boolean noBias = false;
+    private boolean channelsFirst;
 
-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias}, false);
-        Preconditions.checkArgument(bias != null, "LayerNorm: Use constructor without bias argument if bias is null / not available.");
+    public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, channelsFirst, dimensions);
     }
 
     public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) {
@@ -73,7 +73,10 @@ public class LayerNorm extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNorm: You have to provide dimensions");
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        this.bArguments.add(channelsFirst);
     }
 
     @Override
@@ -96,9 +99,9 @@ public class LayerNorm extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradient) {
         SDVariable[] ret;
         if(noBias){
-            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions);
         }else{
-            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions);
         }
         return Arrays.asList(ret);
     }
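Illustrative sketch (not part of the diff): the gain-only overload on a 4D NHWC activation, mirroring the new test further down; shapes are examples only. Because the no-bias constructor now delegates with a null bias and wrapFilterNull drops it, the op receives just {input, gain} as inputs.

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("input", Nd4j.rand(DataType.FLOAT, new long[]{3, 8, 8, 4}));  // NHWC: channels last
    SDVariable gain = sd.var("gain", Nd4j.rand(DataType.FLOAT, new long[]{4}));          // one gain value per channel
    // channelsFirst = false: the channel axis is dimension 3, so gain is broadcast along that axis
    SDVariable out = sd.nn.layerNorm("layernorm", in, gain, false, 1, 2, 3);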

LayerNormBp.java

@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -39,33 +40,28 @@ import java.util.List;
 public class LayerNormBp extends DynamicCustomOp {
 
-    private boolean noBias = false;
+    private boolean channelsFirst;
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias, gradient}, false);
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
+    public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray bias, INDArray grad, INDArray dLdx, INDArray dLdg, INDArray dLdb, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, bias, grad}, new INDArray[]{dLdx, dLdg, dLdb});
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
+    public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
+        super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, gradient}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, gradient, channelsFirst, dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, grad}, new INDArray[]{dLdx, dLdg});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, grad, dLdx, dLdg, null, channelsFirst, dimensions);
     }
 
     @Override
@@ -74,7 +70,10 @@ public class LayerNormBp extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNormBp: You have to provide dimensions");
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        addBArgument(channelsFirst);
     }
 
     @Override
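Illustrative sketch (not part of the diff): exercising the backprop op directly through its no-bias INDArray constructor with preallocated gradient outputs. It assumes Nd4j.exec(CustomOp) as the execution entry point and uses example shapes only.

    INDArray x = Nd4j.rand(DataType.FLOAT, new long[]{10, 4});
    INDArray gain = Nd4j.rand(DataType.FLOAT, new long[]{4});
    INDArray grad = Nd4j.rand(DataType.FLOAT, new long[]{10, 4});   // gradient from the layer above
    INDArray dLdx = Nd4j.create(DataType.FLOAT, 10, 4);             // preallocated output gradients
    INDArray dLdg = Nd4j.create(DataType.FLOAT, 4);
    // No-bias overload: bias and dLdb are omitted; channelsFirst is irrelevant for 2D input
    Nd4j.exec(new LayerNormBp(x, gain, grad, dLdx, dLdg, true, 1));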

LayerOpValidation.java

@@ -1126,7 +1126,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1135,6 +1135,38 @@ public class LayerOpValidation extends BaseOpValidation {
         assertNull(err, err);
     }
 
+    @Test
+    public void testLayerNorm4d() {
+        int mb = 3;
+        int ch = 4;
+        for(boolean nchw : new boolean[]{true, false}) {
+            double eps = 0.0;
+            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray mean = x.mean(true, 1, 2, 3);
+            INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
+            INDArray standardized = x.sub(mean).div(std);
+
+            INDArray exp = standardized.mul(gain4d).add(bias4d);
+
+            final int[] axis = new int[]{1, 2, 3};
+            SameDiff sd = SameDiff.create();
+            SDVariable sdInput = sd.var("input", x);
+            SDVariable sdGain = sd.var("gain", gain4d.reshape(ch));
+            SDVariable sdBias = sd.var("bias", bias4d.reshape(ch));
+            SDVariable out = sd.nn.layerNorm("layernorm", sdInput, sdGain, sdBias, nchw, axis);
+            SDVariable loss = sd.loss.l2Loss(out);
+
+            String err = OpValidation.validate(new TestCase(sd)
+                    .expectedOutput("layernorm", exp)
+                    .gradientCheck(true));
+            assertNull(err);
+        }
+    }
+
     @Test
     public void testLayerNormOP() {
         final INDArray random = Nd4j.rand(new int[]{10, 4});
@@ -1165,7 +1197,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SameDiff sd = SameDiff.create();
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1209,7 +1241,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)