Fixed #9050 regularization loss/override bug

Signed-off-by: jljljl <jijiji95@bk.ru>
(cherry picked from commit 819f3b8c9d5377ed8c3031b4c519f0a3c13e65d3)
master
jljljl 2021-04-02 21:40:57 +05:00 committed by brian
parent a002461812
commit cefec591b0
1 changed files with 81 additions and 12 deletions

View File

@ -26,6 +26,8 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.ArrayList; import java.util.ArrayList;
@ -126,16 +128,83 @@ public class LayerValidation {
} }
} }
private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, List<Regularization> regularization, private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout,
List<Regularization> regularizationBias) { List<Regularization> regularization, List<Regularization> regularizationBias) {
if (regularization != null && !regularization.isEmpty()) { if (regularization != null && !regularization.isEmpty()) {
final List<Regularization> bLayerRegs = bLayer.getRegularization();
if (bLayerRegs == null || bLayerRegs.isEmpty()) {
bLayer.setRegularization(regularization); bLayer.setRegularization(regularization);
} else {
boolean hasL1 = false;
boolean hasL2 = false;
final List<Regularization> regContext = regularization;
for (final Regularization reg : bLayerRegs) {
if (reg instanceof L1Regularization) {
hasL1 = true;
} else if (reg instanceof L2Regularization) {
hasL2 = true;
}
}
for (final Regularization reg : regContext) {
if (reg instanceof L1Regularization) {
if (!hasL1)
bLayerRegs.add(reg);
} else if (reg instanceof L2Regularization) {
if (!hasL2)
bLayerRegs.add(reg);
} else
bLayerRegs.add(reg);
}
}
} }
if (regularizationBias != null && !regularizationBias.isEmpty()) { if (regularizationBias != null && !regularizationBias.isEmpty()) {
final List<Regularization> bLayerRegs = bLayer.getRegularizationBias();
if (bLayerRegs == null || bLayerRegs.isEmpty()) {
bLayer.setRegularizationBias(regularizationBias); bLayer.setRegularizationBias(regularizationBias);
} else {
boolean hasL1 = false;
boolean hasL2 = false;
final List<Regularization> regContext = regularizationBias;
for (final Regularization reg : bLayerRegs) {
if (reg instanceof L1Regularization) {
hasL1 = true;
} else if (reg instanceof L2Regularization) {
hasL2 = true;
}
}
for (final Regularization reg : regContext) {
if (reg instanceof L1Regularization) {
if (!hasL1)
bLayerRegs.add(reg);
} else if (reg instanceof L2Regularization) {
if (!hasL2)
bLayerRegs.add(reg);
} else
bLayerRegs.add(reg);
}
}
} }
if (bLayer.getIDropout() == null) { if (bLayer.getIDropout() == null) {
bLayer.setIDropout(iDropout); bLayer.setIDropout(iDropout);
} }
} }