diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 3bf1754db..e16bd3dc2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -790,20 +790,20 @@ public class DifferentialFunctionFactory {
         return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, bias, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables();
     }
 
     public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
index 928bf3e6e..eb89a0f3a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
@@ -759,8 +759,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return layerNorm(null, input, gain, bias, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return layerNorm(null, input, gain, bias, channelsFirst, dimensions);
     }
 
     /**
@@ -772,13 +772,15 @@ public class SDNN extends SDOps {
      * @param input Input variable
      * @param gain gain
      * @param bias bias
+     * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
+     * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
         validateFloatingPoint("layerNorm", "bias", bias);
-        SDVariable result = f().layerNorm(input, gain, bias, dimensions);
+        SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
 
@@ -789,8 +791,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return layerNorm((String)null, input, gain, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return layerNorm((String)null, input, gain, channelsFirst, dimensions);
     }
 
     /**
@@ -803,10 +805,10 @@ public class SDNN extends SDOps {
      * @param gain gain
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
-        SDVariable result = f().layerNorm(input, gain, dimensions);
+        SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
index f52450eee..27e8ae281 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
@@ -35,6 +35,7 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
 
+import java.lang.reflect.Array;
 import java.util.*;
 
 /**
@@ -611,6 +612,21 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         return in == null ? null : new INDArray[]{in};
     }
 
+    protected static <T> T[] wrapFilterNull(T... in){
+        int count = 0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null)
+                count++;
+        }
+        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
+        int j = 0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null){
+                out[j++] = in[i];
+            }
+        }
+        return out;
+    }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
         Preconditions.checkArgument(dimensions.length > 0, "LayerNorm: You have to provide dimensions");
 
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        this.bArguments.add(channelsFirst);
     }
 
     @Override
@@ -96,9 +99,9 @@ public class LayerNorm extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradient) {
         SDVariable[] ret;
         if(noBias){
-            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions);
         }else{
-            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions);
         }
         return Arrays.asList(ret);
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
index cfd4fff65..2168fd165 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
@@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -39,33 +40,28 @@ import java.util.List;
 public class LayerNormBp extends DynamicCustomOp {
 
     private boolean noBias = false;
+    private boolean channelsFirst;
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias, gradient}, false);
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
-
+    public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray bias, INDArray grad, INDArray dLdx, INDArray dLdg, INDArray dLdb, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, bias, grad}, new INDArray[]{dLdx, dLdg, dLdb});
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
-
+    public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
+        super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, gradient}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, gradient, channelsFirst, dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, grad}, new INDArray[]{dLdx, dLdg});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, grad, dLdx, dLdg, null, channelsFirst, dimensions);
     }
 
     @Override
@@ -74,7 +70,10 @@ public class LayerNormBp extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNormBp: You have to provide dimensions");
 
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        addBArgument(channelsFirst);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index 84bd96ad6..fde2170a6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -1126,7 +1126,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1135,6 +1135,38 @@
         assertNull(err, err);
     }
 
+    @Test
+    public void testLayerNorm4d() {
+        int mb = 3;
+        int ch = 4;
+        for(boolean nchw : new boolean[]{true, false}) {
+            double eps = 0.0;
+            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray mean = x.mean(true, 1, 2, 3);
+            INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
+
+            INDArray standardized = x.sub(mean).div(std);
+            INDArray exp = standardized.mul(gain4d).add(bias4d);
+
+            final int[] axis = new int[]{1, 2, 3};
+            SameDiff sd = SameDiff.create();
+            SDVariable sdInput = sd.var("input", x);
+            SDVariable sdGain = sd.var("gain", gain4d.reshape(ch));
+            SDVariable sdBias = sd.var("bias", bias4d.reshape(ch));
+            SDVariable out = sd.nn.layerNorm("layernorm", sdInput, sdGain, sdBias, nchw, axis);
+
+            SDVariable loss = sd.loss.l2Loss(out);
+
+            String err = OpValidation.validate(new TestCase(sd)
+                    .expectedOutput("layernorm", exp)
+                    .gradientCheck(true));
+            assertNull(err);
+        }
+    }
+
+
     @Test
     public void testLayerNormOP() {
         final INDArray random = Nd4j.rand(new int[]{10, 4});
@@ -1165,7 +1197,7 @@
         SameDiff sd = SameDiff.create();
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1209,7 +1241,7 @@
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
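
Usage sketch (not part of the patch): the snippet below shows how the new channelsFirst flag is passed through the SameDiff API, mirroring the testLayerNorm4d test above. The class name, shapes, and the eval() call are illustrative assumptions only; channelsFirst = true selects NCHW activations, false selects NHWC, and the normalization dimensions are given as the trailing varargs.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormChannelsFirstExample {
    public static void main(String[] args) {
        int mb = 3, ch = 4;

        SameDiff sd = SameDiff.create();
        // NCHW activations; for NHWC data pass channelsFirst = false and shape {mb, 8, 8, ch}
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, new long[]{mb, ch, 8, 8}));
        SDVariable gain = sd.var("gain", Nd4j.rand(DataType.FLOAT, new long[]{ch}));
        SDVariable bias = sd.var("bias", Nd4j.rand(DataType.FLOAT, new long[]{ch}));

        // New signature: layerNorm(name, input, gain, bias, channelsFirst, dimensions...)
        // dimensions = 1,2,3 for CNN activations, dimension = 1 for 2d/MLP data
        SDVariable out = sd.nn.layerNorm("ln", in, gain, bias, true, 1, 2, 3);

        INDArray result = out.eval();
        System.out.println(java.util.Arrays.toString(result.shape()));
    }
}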