From 06e4f5f96e3e84b2608b3a1d4c1762615d232f02 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 20 Jul 2019 12:37:34 +1000 Subject: [PATCH] Small DL4J/SameDiff fixes (#70) * More mask fixes + remove debugging println Signed-off-by: AlexDBlack * Small batch norm derivative fixe Signed-off-by: AlexDBlack --- .../nn/conf/layers/samediff/AbstractSameDiffLayer.java | 8 +++++++- .../deeplearning4j/nn/layers/samediff/SameDiffLayer.java | 1 - .../linalg/api/ops/impl/layers/convolution/BatchNorm.java | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java index 6cf4ae810..49453dd1f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java @@ -216,8 +216,14 @@ public abstract class AbstractSameDiffLayer extends Layer { return Nd4j.ones(input.dataType(), input.size(0), 1); } else if(input.rank() == 3){ return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length] + } else if(input.rank() == 4){ + //CNN style - return [mb, 1, 1, 1] for broadcast... + return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1); + } else if(input.rank() == 5){ + //CNN3D style - return [mb, 1, 1, 1, 1] for broadcast... + return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1); } else { - throw new IllegalStateException("When using masking with rank 4+ inputs, the onesMaskForInput method must be implemented, " + + throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " + "in order to determine the correct mask shape for this layer"); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 7ead0850c..912bc45a8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -166,7 +166,6 @@ public class SameDiffLayer extends AbstractLayer { sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - System.out.println(dLdIn); return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index c03dc919e..67fc9f3a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -162,11 +162,11 @@ public class BatchNorm extends DynamicCustomOp { @Override public List doDiff(List f1) { List ret = new ArrayList<>(); - List inputs = new ArrayList<>(); - inputs.addAll(Arrays.asList(args())); + List inputs = new ArrayList<>(Arrays.asList(args())); inputs.add(f1.get(0)); BatchNormDerivative batchNormDerivative = BatchNormDerivative.derivativeBuilder() .sameDiff(sameDiff) + .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .applyGamma(applyGamma) .applyBeta(applyBeta) .epsilon(epsilon)