parent efbfafe3f7
commit fd22a8ecc7
@@ -93,7 +93,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
             INDArray g = getParam(DefaultParamInitializer.GAIN_KEY);

             INDArray dldg = gradientViews.get(DefaultParamInitializer.GAIN_KEY);
-            Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, 1));
+            Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, true, 1));
             ret.gradientForVariable().put(DefaultParamInitializer.GAIN_KEY, dldg);

         }
@@ -318,7 +318,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
         INDArray preNorm = ret;
         if(hasLayerNorm()){
             preNorm = (forBackprop ? ret.dup(ret.ordering()) : ret);
-            Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, 1));
+            Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, true, 1));
         }

         if(hasBias()){
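Both BaseLayer call sites now pass `true` for the new `channelsFirst` flag immediately ahead of the dimension varargs, leaving the behaviour on [minibatch, size] activations otherwise unchanged. Below is a minimal standalone sketch of the updated forward op; the import path, shapes, and variable names are assumptions for illustration, and only the constructor argument order is taken from this diff.

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormForwardSketch {
    public static void main(String[] args) {
        INDArray activations = Nd4j.rand(4, 8);   // hypothetical [minibatch, layerSize] activations
        INDArray gain = Nd4j.ones(8);              // per-feature gain, i.e. the GAIN_KEY parameter in BaseLayer
        INDArray out = Nd4j.zeros(4, 8);           // result buffer written by the op

        // channelsFirst = true, normalising over dimension 1, mirroring the BaseLayer call above
        Nd4j.getExecutioner().exec(new LayerNorm(activations, gain, out, true, 1));
        System.out.println(out);
    }
}
```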
@@ -158,7 +158,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
                dldnCurrent = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), gxg.shape());
                INDArray bgCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, bg.dataType(), bg.shape());
-                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, ggCur, bgCur, 1));
+                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, ggCur, bgCur, true, 1));
                gxg.addi(ggCur);
                bg.addi(bgCur);
            }else{
@@ -177,7 +177,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
            if(hasLayerNorm() && i > end){
                dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape());
-                Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, 1));
+                Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1));
                grg.addi(ggCur);
            }else{
                dldzNext = dldzCurrent;
@@ -256,7 +256,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
            if(hasLayerNorm()){
                INDArray currOutPreNorm = (forBackprop ? outPreNorm : out).get(all(), all(), point(i));
                Nd4j.gemm(currIn, w, currOutPreNorm, false, false, 1.0, 0.0);
-                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, 1));
+                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, true, 1));
            }else{
                Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0); //beta = 1.0 to keep previous contents (bias)
            }
@@ -266,7 +266,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
                INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');;
                Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0);
                INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
-                Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, 1));
+                Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, true, 1));
                currOut.addi(recNorm);
            }else {
                Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0); //beta = 1.0 to keep previous contents
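In SimpleRnn the per-timestep backprop calls gain the same flag: `LayerNormBp` still takes the inputs, the incoming gradient, and the three gradient outputs, with `true` now inserted just before the normalisation dimension, and the per-step gain and bias gradients are still accumulated into the parameter views via `addi`. The following is a hedged sketch of a single such call outside the RNN loop; shapes, names, and the import path are assumptions, while the argument order mirrors the call shown above.

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormBpSketch {
    public static void main(String[] args) {
        int minibatch = 4, size = 8;
        INDArray preNorm = Nd4j.rand(minibatch, size);   // pre-normalisation activations for one step
        INDArray gain = Nd4j.ones(size);                 // per-feature gain
        INDArray bias = Nd4j.zeros(size);                // per-feature bias
        INDArray dldOut = Nd4j.rand(minibatch, size);    // gradient flowing in from the step above

        INDArray dldIn = Nd4j.zeros(minibatch, size);    // epsilon to propagate further back
        INDArray dldGain = Nd4j.zeros(size);             // would be accumulated into gxg/grg in SimpleRnn
        INDArray dldBias = Nd4j.zeros(size);             // would be accumulated into bg in SimpleRnn

        // Argument order mirrors the SimpleRnn call: inputs, upstream gradient, the three
        // gradient outputs, then channelsFirst = true and the normalisation dimension.
        Nd4j.getExecutioner().exec(
                new LayerNormBp(preNorm, gain, bias, dldOut, dldIn, dldGain, dldBias, true, 1));
    }
}
```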
@@ -54,17 +54,14 @@ public class LayerNorm extends DynamicCustomOp {
        this(sameDiff, input, gain, null, channelsFirst, dimensions);
    }

-    public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) {
-        super("layer_norm", new INDArray[]{input, gain, bias}, new INDArray[]{result});
-        Preconditions.checkArgument(bias != null, "LayerNorm: Use different constructor if bias is null.");
-
+    public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, boolean channelsFirst, int... dimensions) {
+        super("layer_norm", wrapFilterNull(input, gain, bias), wrapOrNull(result));
+        this.channelsFirst = channelsFirst;
        setDimensions(dimensions);
    }

-    public LayerNorm(INDArray input, INDArray gain, INDArray result, int... dimensions) {
-        super("layer_norm", new INDArray[]{input, gain}, new INDArray[]{result});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, result, channelsFirst, dimensions);
    }

    @Override
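The constructor refactor drops the explicit `Preconditions.checkArgument(bias != null, ...)` and the `noBias` assignment in favour of `wrapFilterNull`/`wrapOrNull`, so a null bias is simply filtered out of the op inputs and the bias-free constructor becomes a one-line delegation. Below is a rough illustration of the helper semantics assumed here (not the actual `DynamicCustomOp` implementation):

```java
import java.util.ArrayList;
import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;

class WrapHelperSketch {
    // Assumed behaviour of DynamicCustomOp.wrapFilterNull: keep only the non-null arrays, so
    // new LayerNorm(x, g, null, out, true, 1) ends up with op inputs {x, g} instead of failing.
    static INDArray[] wrapFilterNull(INDArray... arrays) {
        List<INDArray> kept = new ArrayList<>();
        for (INDArray a : arrays) {
            if (a != null) {
                kept.add(a);
            }
        }
        return kept.toArray(new INDArray[0]);
    }

    // Assumed behaviour of wrapOrNull: wrap a single array, or pass a null result straight through.
    static INDArray[] wrapOrNull(INDArray array) {
        return array == null ? null : new INDArray[]{array};
    }
}
```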