diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java index 55713cc08..2a5f16be6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java @@ -26,6 +26,8 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.nd4j.linalg.learning.regularization.L1Regularization; +import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import java.util.ArrayList; @@ -126,17 +128,84 @@ public class LayerValidation { } } - private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, List regularization, - List regularizationBias) { - if (regularization != null && !regularization.isEmpty()) { - bLayer.setRegularization(regularization); - } - if (regularizationBias != null && !regularizationBias.isEmpty()) { - bLayer.setRegularizationBias(regularizationBias); - } + private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, + List regularization, List regularizationBias) { + if (regularization != null && !regularization.isEmpty()) { - if (bLayer.getIDropout() == null) { - bLayer.setIDropout(iDropout); - } - } + final List bLayerRegs = bLayer.getRegularization(); + if (bLayerRegs == null || bLayerRegs.isEmpty()) { + + bLayer.setRegularization(regularization); + } else { + + boolean hasL1 = false; + boolean hasL2 = false; + final List regContext = regularization; + for (final Regularization reg : bLayerRegs) { + + if (reg instanceof L1Regularization) { + + hasL1 = true; + } else if (reg instanceof L2Regularization) { + + hasL2 = true; + } + } + for (final Regularization reg : regContext) { + + if (reg instanceof L1Regularization) { + + if (!hasL1) + bLayerRegs.add(reg); + } else if (reg instanceof L2Regularization) { + + if (!hasL2) + bLayerRegs.add(reg); + } else + bLayerRegs.add(reg); + } + } + } + if (regularizationBias != null && !regularizationBias.isEmpty()) { + + final List bLayerRegs = bLayer.getRegularizationBias(); + if (bLayerRegs == null || bLayerRegs.isEmpty()) { + + bLayer.setRegularizationBias(regularizationBias); + } else { + + boolean hasL1 = false; + boolean hasL2 = false; + final List regContext = regularizationBias; + for (final Regularization reg : bLayerRegs) { + + if (reg instanceof L1Regularization) { + + hasL1 = true; + } else if (reg instanceof L2Regularization) { + + hasL2 = true; + } + } + for (final Regularization reg : regContext) { + + if (reg instanceof L1Regularization) { + + if (!hasL1) + bLayerRegs.add(reg); + } else if (reg instanceof L2Regularization) { + + if (!hasL2) + bLayerRegs.add(reg); + } else + bLayerRegs.add(reg); + } + } + } + + if (bLayer.getIDropout() == null) { + + bLayer.setIDropout(iDropout); + } + } }