parent 05d45ec050
commit dce4751fc1

@@ -790,20 +790,20 @@ public class DifferentialFunctionFactory {
         return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable();
     }

-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, bias, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable();
     }

-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables();
     }

-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable();
     }

-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables();
     }

     public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) {

@@ -759,8 +759,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return layerNorm(null, input, gain, bias, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return layerNorm(null, input, gain, bias, channelsFirst, dimensions);
     }

     /**
@@ -772,13 +772,15 @@ public class SDNN extends SDOps {
      * @param input Input variable
      * @param gain gain
      * @param bias bias
+     * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
+     * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
         validateFloatingPoint("layerNorm", "bias", bias);
-        SDVariable result = f().layerNorm(input, gain, bias, dimensions);
+        SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }

@@ -789,8 +791,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return layerNorm((String)null, input, gain, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return layerNorm((String)null, input, gain, channelsFirst, dimensions);
     }

     /**
@@ -803,10 +805,10 @@ public class SDNN extends SDOps {
      * @param gain gain
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
-        SDVariable result = f().layerNorm(input, gain, dimensions);
+        SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }

@@ -35,6 +35,7 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;

+import java.lang.reflect.Array;
 import java.util.*;

 /**
@@ -611,6 +612,21 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         return in == null ? null : new INDArray[]{in};
     }

+    protected static <T> T[] wrapFilterNull(T... in){
+        int count = 0;
+        for( int i=0; i<in.length; i++ ) {
+            if (in[i] != null) count++;
+        }
+        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
+        int j=0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null){
+                out[j++] = in[i];
+            }
+        }
+        return out;
+    }
+
     public static class DynamicCustomOpsBuilder {
         protected String opName;
         protected int numInputs;

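Note (not part of the diff): wrapFilterNull simply drops null entries from a varargs array, which is what lets the bias-less LayerNorm/LayerNormBp constructors below pass bias = null straight through. A standalone sketch of that behaviour, with hypothetical demo names:

import java.lang.reflect.Array;
import java.util.Arrays;

public class WrapFilterNullDemo {
    // Same logic as the helper added above, reproduced here for illustration only
    static <T> T[] wrapFilterNull(T... in) {
        int count = 0;
        for (T t : in) if (t != null) count++;
        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
        int j = 0;
        for (T t : in) if (t != null) out[j++] = t;
        return out;
    }

    public static void main(String[] args) {
        // With bias == null, only {input, gain} survive, so the op receives two inputs
        String[] filtered = wrapFilterNull("input", "gain", null);
        System.out.println(Arrays.toString(filtered)); // prints [input, gain]
    }
}
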
@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;

 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -41,17 +42,16 @@ import java.util.List;
 public class LayerNorm extends DynamicCustomOp {

     private boolean noBias = false;
+    private boolean channelsFirst;

-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias}, false);
-        Preconditions.checkArgument(bias != null, "LayerNorm: Use constructor without bias argument if bias is null / not available.");
+    public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }

-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, channelsFirst, dimensions);
     }

     public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) {
@@ -73,7 +73,10 @@ public class LayerNorm extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNorm: You have to provide dimensions");

         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        this.bArguments.add(channelsFirst);
     }

     @Override
@@ -96,9 +99,9 @@ public class LayerNorm extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradient) {
         SDVariable[] ret;
         if(noBias){
-            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions);
         }else{
-            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions);
         }
         return Arrays.asList(ret);
     }

@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;

 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -39,33 +40,28 @@ import java.util.List;
 public class LayerNormBp extends DynamicCustomOp {

     private boolean noBias = false;
+    private boolean channelsFirst;


-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias, gradient}, false);
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
-
+    public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }

-    public LayerNormBp(INDArray input, INDArray gain, INDArray bias, INDArray grad, INDArray dLdx, INDArray dLdg, INDArray dLdb, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, bias, grad}, new INDArray[]{dLdx, dLdg, dLdb});
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
-
+    public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
+        super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }


-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, gradient}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, gradient, channelsFirst, dimensions);
     }

-    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, grad}, new INDArray[]{dLdx, dLdg});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, grad, dLdx, dLdg, null, channelsFirst, dimensions);
     }

     @Override
@@ -74,7 +70,10 @@ public class LayerNormBp extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNormBp: You have to provide dimensions");

         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        addBArgument(channelsFirst);
     }

     @Override

@@ -1126,7 +1126,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");

         String err = OpValidation.validate(new TestCase(sd)
@@ -1135,6 +1135,38 @@ public class LayerOpValidation extends BaseOpValidation {
         assertNull(err, err);
     }

+    @Test
+    public void testLayerNorm4d() {
+        int mb = 3;
+        int ch = 4;
+        for(boolean nchw : new boolean[]{true, false}) {
+            double eps = 0.0;
+            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray mean = x.mean(true, 1, 2, 3);
+            INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
+
+            INDArray standardized = x.sub(mean).div(std);
+            INDArray exp = standardized.mul(gain4d).add(bias4d);
+
+            final int[] axis = new int[]{1, 2, 3};
+            SameDiff sd = SameDiff.create();
+            SDVariable sdInput = sd.var("input", x);
+            SDVariable sdGain = sd.var("gain", gain4d.reshape(ch));
+            SDVariable sdBias = sd.var("bias", bias4d.reshape(ch));
+            SDVariable out = sd.nn.layerNorm("layernorm", sdInput, sdGain, sdBias, nchw, axis);
+
+            SDVariable loss = sd.loss.l2Loss(out);
+
+            String err = OpValidation.validate(new TestCase(sd)
+                    .expectedOutput("layernorm", exp)
+                    .gradientCheck(true));
+            assertNull(err);
+        }
+    }
+
+
     @Test
     public void testLayerNormOP() {
         final INDArray random = Nd4j.rand(new int[]{10, 4});
@@ -1165,7 +1197,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SameDiff sd = SameDiff.create();
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, true, axis);
         out.norm1("out");

         String err = OpValidation.validate(new TestCase(sd)
@@ -1209,7 +1241,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");

         String err = OpValidation.validate(new TestCase(sd)

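Taken together, the commit threads a channelsFirst flag from the SameDiff layerNorm API down to the layer_norm / layer_norm_bp ops (passed as a boolean op argument), so 4D activations can be normalized in either NCHW or NHWC layout. A minimal usage sketch against the new overloads; the class name, shapes and variable names here are illustrative, not taken from the commit:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormChannelsFirstSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // NCHW activations: minibatch 2, 4 channels, 8x8 spatial
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, new long[]{2, 4, 8, 8}));
        SDVariable gain = sd.var("gain", Nd4j.rand(DataType.FLOAT, new long[]{4}));
        SDVariable bias = sd.var("bias", Nd4j.rand(DataType.FLOAT, new long[]{4}));

        // New overload: channelsFirst = true for NCHW data, normalizing over dims 1,2,3 as in testLayerNorm4d;
        // for NHWC input the same call would pass channelsFirst = false
        SDVariable out = sd.nn.layerNorm("layernorm", in, gain, bias, true, 1, 2, 3);

        // Build a scalar loss so the graph can be differentiated, mirroring the new test
        SDVariable loss = sd.loss.l2Loss(out);
    }
}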