Fixed #9050 regularization loss/override bug

Signed-off-by: jljljl <jijiji95@bk.ru>
(cherry picked from commit 819f3b8c9d5377ed8c3031b4c519f0a3c13e65d3)

Branch: master
Parent: a002461812
Commit: cefec591b0
LayerValidation.java

@@ -26,6 +26,8 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.dropout.IDropout;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
 import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
+import org.nd4j.linalg.learning.regularization.L1Regularization;
+import org.nd4j.linalg.learning.regularization.L2Regularization;
 import org.nd4j.linalg.learning.regularization.Regularization;
 
 import java.util.ArrayList;

@@ -126,16 +128,83 @@ public class LayerValidation {
         }
     }
 
-    private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, List<Regularization> regularization,
-                                           List<Regularization> regularizationBias) {
+    private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout,
+                                           List<Regularization> regularization, List<Regularization> regularizationBias) {
+        if (regularization != null && !regularization.isEmpty()) {
+
+            final List<Regularization> bLayerRegs = bLayer.getRegularization();
+            if (bLayerRegs == null || bLayerRegs.isEmpty()) {
+
+                bLayer.setRegularization(regularization);
+            } else {
+
+                boolean hasL1 = false;
+                boolean hasL2 = false;
+                final List<Regularization> regContext = regularization;
+                for (final Regularization reg : bLayerRegs) {
+
+                    if (reg instanceof L1Regularization) {
+
+                        hasL1 = true;
+                    } else if (reg instanceof L2Regularization) {
+
+                        hasL2 = true;
+                    }
+                }
+                for (final Regularization reg : regContext) {
+
+                    if (reg instanceof L1Regularization) {
+
+                        if (!hasL1)
+                            bLayerRegs.add(reg);
+                    } else if (reg instanceof L2Regularization) {
+
+                        if (!hasL2)
+                            bLayerRegs.add(reg);
+                    } else
+                        bLayerRegs.add(reg);
+                }
+            }
+        }
+        if (regularizationBias != null && !regularizationBias.isEmpty()) {
+
+            final List<Regularization> bLayerRegs = bLayer.getRegularizationBias();
+            if (bLayerRegs == null || bLayerRegs.isEmpty()) {
+
+                bLayer.setRegularizationBias(regularizationBias);
+            } else {
+
+                boolean hasL1 = false;
+                boolean hasL2 = false;
+                final List<Regularization> regContext = regularizationBias;
+                for (final Regularization reg : bLayerRegs) {
+
+                    if (reg instanceof L1Regularization) {
+
+                        hasL1 = true;
+                    } else if (reg instanceof L2Regularization) {
+
+                        hasL2 = true;
+                    }
+                }
+                for (final Regularization reg : regContext) {
+
+                    if (reg instanceof L1Regularization) {
+
+                        if (!hasL1)
+                            bLayerRegs.add(reg);
+                    } else if (reg instanceof L2Regularization) {
+
+                        if (!hasL2)
+                            bLayerRegs.add(reg);
+                    } else
+                        bLayerRegs.add(reg);
+                }
+            }
+        }
+
+        if (bLayer.getIDropout() == null) {
+
+            bLayer.setIDropout(iDropout);
+        }
+    }
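To make the behavioral change concrete: judging from the merge logic above, global (builder-level) L1/L2 terms are now added to a layer's existing regularization list only where the layer does not already define a term of that type, rather than replacing the layer's own settings. The snippet below is a minimal usage sketch, not part of the commit; it assumes the standard DL4J builder API (NeuralNetConfiguration.Builder, DenseLayer, OutputLayer, the l1/l2 convenience setters, and MultiLayerConfiguration.getConf), whose exact locations and defaults can vary between DL4J versions.

import java.util.List;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class RegularizationMergeSketch {
    public static void main(String[] args) {
        // Global L2 on the builder, plus a first layer that defines its own L1.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .l2(1e-4)                               // network-wide L2 term
                .list()
                .layer(new DenseLayer.Builder()
                        .nIn(10).nOut(10)
                        .l1(1e-3)                       // layer-specific L1 term
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(10).nOut(2)
                        .build())
                .build();

        // With the merge logic from the diff, layer 0 keeps its own L1 and
        // additionally receives the global L2; the global setting no longer
        // replaces the layer's regularization list wholesale.
        BaseLayer layer0 = (BaseLayer) conf.getConf(0).getLayer();
        List<Regularization> regs = layer0.getRegularization();
        System.out.println(regs); // expected: one L1Regularization and one L2Regularization
    }
}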