diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 405889f0e..00ca7e7c4 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -93,7 +93,7 @@ public abstract class BaseLayer end){ dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape()); INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape()); - Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, 1)); + Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1)); grg.addi(ggCur); }else{ dldzNext = dldzCurrent; @@ -256,7 +256,7 @@ public class SimpleRnn extends BaseRecurrentLayer