parent efbfafe3f7
commit fd22a8ecc7
@@ -93,7 +93,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
             INDArray g = getParam(DefaultParamInitializer.GAIN_KEY);

             INDArray dldg = gradientViews.get(DefaultParamInitializer.GAIN_KEY);
-            Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, 1));
+            Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, true, 1));
             ret.gradientForVariable().put(DefaultParamInitializer.GAIN_KEY, dldg);

         }
@@ -318,7 +318,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
         INDArray preNorm = ret;
         if(hasLayerNorm()){
             preNorm = (forBackprop ? ret.dup(ret.ordering()) : ret);
-            Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, 1));
+            Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, true, 1));
         }

         if(hasBias()){
@@ -158,7 +158,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
                 dldnCurrent = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                 INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), gxg.shape());
                 INDArray bgCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, bg.dataType(), bg.shape());
-                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, ggCur, bgCur, 1));
+                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, ggCur, bgCur, true, 1));
                 gxg.addi(ggCur);
                 bg.addi(bgCur);
             }else{
@@ -177,7 +177,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
                 if(hasLayerNorm() && i > end){
                     dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                     INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape());
-                    Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, 1));
+                    Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1));
                     grg.addi(ggCur);
                 }else{
                     dldzNext = dldzCurrent;
@@ -256,7 +256,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
             if(hasLayerNorm()){
                 INDArray currOutPreNorm = (forBackprop ? outPreNorm : out).get(all(), all(), point(i));
                 Nd4j.gemm(currIn, w, currOutPreNorm, false, false, 1.0, 0.0);
-                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, 1));
+                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, true, 1));
             }else{
                 Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0); //beta = 1.0 to keep previous contents (bias)
             }
@@ -266,7 +266,7 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
                 INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');;
                 Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0);
                 INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
-                Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, 1));
+                Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, true, 1));
                 currOut.addi(recNorm);
             }else {
                 Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0); //beta = 1.0 to keep previous contents
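For reference, a minimal standalone sketch of calling the backprop op directly with the new flag, mirroring the call sites above; the class name, shapes, variable names, and the org.nd4j.linalg.api.ops.impl.transforms.custom import path are illustrative assumptions, not part of the patch.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormBpSketch {
    public static void main(String[] args) {
        // Illustrative shapes: [minibatch, features], normalizing over dimension 1
        INDArray x    = Nd4j.rand(8, 32);      // pre-norm activations
        INDArray gain = Nd4j.ones(32);
        INDArray bias = Nd4j.zeros(32);
        INDArray dLdy = Nd4j.rand(8, 32);      // gradient w.r.t. the layer-norm output

        // Output buffers for the three gradients the op writes, as at the call sites above
        INDArray dLdx = Nd4j.zerosLike(x);
        INDArray dLdg = Nd4j.zerosLike(gain);
        INDArray dLdb = Nd4j.zerosLike(bias);

        // channelsFirst = true, as passed at every call site in this commit
        Nd4j.getExecutioner().exec(new LayerNormBp(x, gain, bias, dLdy, dLdx, dLdg, dLdb, true, 1));
    }
}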
@@ -54,17 +54,14 @@ public class LayerNorm extends DynamicCustomOp {
         this(sameDiff, input, gain, null, channelsFirst, dimensions);
     }

-    public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) {
-        super("layer_norm", new INDArray[]{input, gain, bias}, new INDArray[]{result});
-        Preconditions.checkArgument(bias != null, "LayerNorm: Use different constructor if bias is null.");
-
+    public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, boolean channelsFirst, int... dimensions) {
+        super("layer_norm", wrapFilterNull(input, gain, bias), wrapOrNull(result));
+        this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }

-    public LayerNorm(INDArray input, INDArray gain, INDArray result, int... dimensions) {
-        super("layer_norm", new INDArray[]{input, gain}, new INDArray[]{result});
-        noBias = true;
-        setDimensions(dimensions);
+    public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... dimensions) {
+        this(input, gain, null, result, channelsFirst, dimensions);
     }

     @Override
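And a matching sketch of the reworked forward constructors; as above, the example class, shapes, and import path are assumptions for illustration, not part of the patch.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.factory.Nd4j;

public class LayerNormForwardSketch {
    public static void main(String[] args) {
        INDArray x    = Nd4j.rand(8, 32);      // [minibatch, features]
        INDArray gain = Nd4j.ones(32);
        INDArray bias = Nd4j.zeros(32);
        INDArray out  = Nd4j.zerosLike(x);

        // Full constructor: gain + bias, explicit channelsFirst flag, normalize over dimension 1
        Nd4j.getExecutioner().exec(new LayerNorm(x, gain, bias, out, true, 1));

        // Bias-less constructor: after this change it simply delegates with a null bias
        Nd4j.getExecutioner().exec(new LayerNorm(x, gain, out, true, 1));
    }
}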