From fd22a8ecc753b4cbddc89e8eb8260fc439bbda90 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 27 Aug 2019 19:46:26 +1000 Subject: [PATCH] Small build fix, after last PR (#177) Signed-off-by: Alex Black --- .../org/deeplearning4j/nn/layers/BaseLayer.java | 4 ++-- .../nn/layers/recurrent/SimpleRnn.java | 8 ++++---- .../api/ops/impl/transforms/custom/LayerNorm.java | 13 +++++-------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 405889f0e..00ca7e7c4 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -93,7 +93,7 @@ public abstract class BaseLayer end){ dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape()); INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape()); - Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, 1)); + Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1)); grg.addi(ggCur); }else{ dldzNext = dldzCurrent; @@ -256,7 +256,7 @@ public class SimpleRnn extends BaseRecurrentLayer