parent 05d45ec050
commit dce4751fc1

@@ -790,20 +790,20 @@ public class DifferentialFunctionFactory {
         return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, bias, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables();
     }
 
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return new LayerNorm(sameDiff(), input, gain, dimensions).outputVariable();
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable();
     }
 
-    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        return new LayerNormBp(sameDiff(), input, gain, gradient, dimensions).outputVariables();
+    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables();
     }
 
     public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) {

@@ -759,8 +759,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        return layerNorm(null, input, gain, bias, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        return layerNorm(null, input, gain, bias, channelsFirst, dimensions);
     }
 
     /**
@@ -772,13 +772,15 @@ public class SDNN extends SDOps {
      * @param input Input variable
      * @param gain gain
      * @param bias bias
+     * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data
+     * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
         validateFloatingPoint("layerNorm", "bias", bias);
-        SDVariable result = f().layerNorm(input, gain, bias, dimensions);
+        SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
 
@@ -789,8 +791,8 @@ public class SDNN extends SDOps {
      *
      * @return Output variable
      */
-    public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) {
-        return layerNorm((String)null, input, gain, dimensions);
+    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        return layerNorm((String)null, input, gain, channelsFirst, dimensions);
     }
 
     /**
@@ -803,10 +805,10 @@ public class SDNN extends SDOps {
      * @param gain gain
      * @return Output variable
      */
-    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... dimensions) {
+    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
         validateFloatingPoint("layerNorm", "input", input);
         validateFloatingPoint("layerNorm", "gain", gain);
-        SDVariable result = f().layerNorm(input, gain, dimensions);
+        SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions);
         return updateVariableNameAndReference(result, name);
     }
 
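For orientation, here is a minimal usage sketch of the updated SDNN API. It is illustrative only: the class name, shapes, and variable names below are assumptions and not part of this commit. The new boolean selects between NCHW and NHWC interpretation of 4D input, while the dimensions vararg still controls which axes are normalized over.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

// Hypothetical example class, not part of this commit.
public class LayerNormUsageSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // NCHW activations: minibatch=3, channels=4, height=8, width=8
        SDVariable input = sd.var("input", Nd4j.rand(DataType.FLOAT, new long[]{3, 4, 8, 8}));
        SDVariable gain = sd.var("gain", Nd4j.rand(DataType.FLOAT, new long[]{4}));
        SDVariable bias = sd.var("bias", Nd4j.rand(DataType.FLOAT, new long[]{4}));
        // channelsFirst = true because the input is NCHW; normalize over dims 1,2,3 (the CNN case)
        SDVariable out = sd.nn.layerNorm("ln", input, gain, bias, true, 1, 2, 3);
        // "out" now holds the layer-normalized activations
    }
}

For 2D/MLP input the flag is ignored, matching the javadoc added in the hunk above.
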
@@ -35,6 +35,7 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
 
+import java.lang.reflect.Array;
 import java.util.*;
 
 /**
@@ -611,6 +612,21 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         return in == null ? null : new INDArray[]{in};
     }
 
+    protected static <T> T[] wrapFilterNull(T... in){
+        int count = 0;
+        for( int i=0; i<in.length; i++ ) {
+            if (in[i] != null) count++;
+        }
+        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
+        int j=0;
+        for( int i=0; i<in.length; i++ ){
+            if(in[i] != null){
+                out[j++] = in[i];
+            }
+        }
+        return out;
+    }
+
     public static class DynamicCustomOpsBuilder {
         protected String opName;
         protected int numInputs;

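The new wrapFilterNull helper copies only the non-null entries of a varargs array, preserving order. This is what lets the LayerNorm/LayerNormBp constructors below funnel their optional-bias overloads through a single code path. A self-contained sketch of the same contract follows; the helper itself is protected inside DynamicCustomOp, so this standalone copy is illustrative only.

import java.lang.reflect.Array;
import java.util.Arrays;

// Hypothetical standalone copy of the helper, for illustration of its contract.
public class WrapFilterNullSketch {
    @SafeVarargs
    static <T> T[] wrapFilterNull(T... in) {
        // Count the non-null entries, then copy them into a new array of the same component type.
        int count = 0;
        for (T t : in) if (t != null) count++;
        @SuppressWarnings("unchecked")
        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
        int j = 0;
        for (T t : in) if (t != null) out[j++] = t;
        return out;
    }

    public static void main(String[] args) {
        // A null "bias" is dropped; the order of the remaining entries is preserved.
        System.out.println(Arrays.toString(wrapFilterNull("input", "gain", null, "grad")));
        // -> [input, gain, grad]
    }
}

Passing (input, gain, null) therefore yields a two-element argument array, so the underlying op sees the same layout as the old dedicated no-bias constructor.
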
@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -41,17 +42,16 @@ import java.util.List;
 public class LayerNorm extends DynamicCustomOp {
 
     private boolean noBias = false;
+    private boolean channelsFirst;
 
-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias}, false);
-        Preconditions.checkArgument(bias != null, "LayerNorm: Use constructor without bias argument if bias is null / not available.");
+    public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias), false);
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
-    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, channelsFirst, dimensions);
     }
 
     public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) {
@@ -73,7 +73,10 @@ public class LayerNorm extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNorm: You have to provide dimensions");
 
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        this.bArguments.add(channelsFirst);
     }
 
     @Override
@@ -96,9 +99,9 @@ public class LayerNorm extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradient) {
         SDVariable[] ret;
         if(noBias){
-            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions);
         }else{
-            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), dimensions);
+            ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions);
         }
         return Arrays.asList(ret);
     }

@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -39,33 +40,28 @@ import java.util.List;
 public class LayerNormBp extends DynamicCustomOp {
 
     private boolean noBias = false;
+    private boolean channelsFirst;
 
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, bias, gradient}, false);
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
+    public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.channelsFirst = channelsFirst;
 
         setDimensions(dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray bias, INDArray grad, INDArray dLdx, INDArray dLdg, INDArray dLdb, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, bias, grad}, new INDArray[]{dLdx, dLdg, dLdb});
-        Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available.");
+    public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
+        super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.channelsFirst = channelsFirst;
 
         setDimensions(dimensions);
     }
 
 
-    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) {
-        super(null, sameDiff, new SDVariable[] {input, gain, gradient}, false);
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) {
+        this(sameDiff, input, gain, null, gradient, channelsFirst, dimensions);
     }
 
-    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, int... dimensions) {
-        super("layer_norm_bp", new INDArray[]{input, gain, grad}, new INDArray[]{dLdx, dLdg});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, grad, dLdx, dLdg, null, channelsFirst, dimensions);
     }
 
     @Override
@@ -74,7 +70,10 @@ public class LayerNormBp extends DynamicCustomOp {
         Preconditions.checkArgument(dimensions.length > 0, "LayerNormBp: You have to provide dimensions");
 
         this.dimensions = dimensions;
+        this.iArguments.clear();
         addIArgument(dimensions);
+        this.bArguments.clear();
+        addBArgument(channelsFirst);
     }
 
     @Override

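Since setDimensions now rewrites the op's integer and boolean argument lists instead of appending to them, repeated calls cannot leave stale dimension or flag values behind. Below is a hedged sketch of driving the backprop op directly through the new INDArray constructor; the shapes mirror the existing 2D test, while the class name, pre-allocated output arrays, and executioner call are assumptions based on the standard Nd4j CustomOp execution path and are not part of this commit.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.factory.Nd4j;

// Hypothetical example class, not part of this commit.
public class LayerNormBpSketch {
    public static void main(String[] args) {
        // 2D/MLP case: normalize over dimension 1; channelsFirst is unused for 2D input.
        INDArray input = Nd4j.rand(DataType.FLOAT, 10, 4);
        INDArray gain = Nd4j.rand(DataType.FLOAT, 4);
        INDArray bias = Nd4j.rand(DataType.FLOAT, 4);
        INDArray grad = Nd4j.rand(DataType.FLOAT, 10, 4);
        // Pre-allocated outputs: dL/dx, dL/dgain, dL/dbias
        INDArray dLdx = Nd4j.create(DataType.FLOAT, 10, 4);
        INDArray dLdg = Nd4j.create(DataType.FLOAT, 4);
        INDArray dLdb = Nd4j.create(DataType.FLOAT, 4);
        // Executes the layer_norm_bp op with the new channelsFirst flag and dimension 1.
        Nd4j.getExecutioner().exec(new LayerNormBp(input, gain, bias, grad, dLdx, dLdg, dLdb, true, 1));
    }
}
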
@@ -1126,7 +1126,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1135,6 +1135,38 @@ public class LayerOpValidation extends BaseOpValidation {
         assertNull(err, err);
     }
 
+    @Test
+    public void testLayerNorm4d() {
+        int mb = 3;
+        int ch = 4;
+        for(boolean nchw : new boolean[]{true, false}) {
+            double eps = 0.0;
+            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray mean = x.mean(true, 1, 2, 3);
+            INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
+
+            INDArray standardized = x.sub(mean).div(std);
+            INDArray exp = standardized.mul(gain4d).add(bias4d);
+
+            final int[] axis = new int[]{1, 2, 3};
+            SameDiff sd = SameDiff.create();
+            SDVariable sdInput = sd.var("input", x);
+            SDVariable sdGain = sd.var("gain", gain4d.reshape(ch));
+            SDVariable sdBias = sd.var("bias", bias4d.reshape(ch));
+            SDVariable out = sd.nn.layerNorm("layernorm", sdInput, sdGain, sdBias, nchw, axis);
+
+            SDVariable loss = sd.loss.l2Loss(out);
+
+            String err = OpValidation.validate(new TestCase(sd)
+                    .expectedOutput("layernorm", exp)
+                    .gradientCheck(true));
+            assertNull(err);
+        }
+    }
+
 
     @Test
     public void testLayerNormOP() {
         final INDArray random = Nd4j.rand(new int[]{10, 4});
@@ -1165,7 +1197,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SameDiff sd = SameDiff.create();
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)
@@ -1209,7 +1241,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable sdInput = sd.var("input", standardized);
         SDVariable sdGain = sd.var("gain", gain);
         SDVariable sdBias = sd.var("bias", bias);
-        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);
+        SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);
         out.norm1("out");
 
         String err = OpValidation.validate(new TestCase(sd)