Fixed #9050 regularization loss/override bug
Signed-off-by: jljljl <jijiji95@bk.ru> (cherry picked from commit 819f3b8c9d5377ed8c3031b4c519f0a3c13e65d3)master
parent
a002461812
commit
cefec591b0
|
@ -26,6 +26,8 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
|
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
||||||
|
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||||
|
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -126,17 +128,84 @@ public class LayerValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, List<Regularization> regularization,
|
private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout,
|
||||||
List<Regularization> regularizationBias) {
|
List<Regularization> regularization, List<Regularization> regularizationBias) {
|
||||||
if (regularization != null && !regularization.isEmpty()) {
|
if (regularization != null && !regularization.isEmpty()) {
|
||||||
bLayer.setRegularization(regularization);
|
|
||||||
}
|
|
||||||
if (regularizationBias != null && !regularizationBias.isEmpty()) {
|
|
||||||
bLayer.setRegularizationBias(regularizationBias);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bLayer.getIDropout() == null) {
|
final List<Regularization> bLayerRegs = bLayer.getRegularization();
|
||||||
bLayer.setIDropout(iDropout);
|
if (bLayerRegs == null || bLayerRegs.isEmpty()) {
|
||||||
}
|
|
||||||
}
|
bLayer.setRegularization(regularization);
|
||||||
|
} else {
|
||||||
|
|
||||||
|
boolean hasL1 = false;
|
||||||
|
boolean hasL2 = false;
|
||||||
|
final List<Regularization> regContext = regularization;
|
||||||
|
for (final Regularization reg : bLayerRegs) {
|
||||||
|
|
||||||
|
if (reg instanceof L1Regularization) {
|
||||||
|
|
||||||
|
hasL1 = true;
|
||||||
|
} else if (reg instanceof L2Regularization) {
|
||||||
|
|
||||||
|
hasL2 = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (final Regularization reg : regContext) {
|
||||||
|
|
||||||
|
if (reg instanceof L1Regularization) {
|
||||||
|
|
||||||
|
if (!hasL1)
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
} else if (reg instanceof L2Regularization) {
|
||||||
|
|
||||||
|
if (!hasL2)
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
} else
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (regularizationBias != null && !regularizationBias.isEmpty()) {
|
||||||
|
|
||||||
|
final List<Regularization> bLayerRegs = bLayer.getRegularizationBias();
|
||||||
|
if (bLayerRegs == null || bLayerRegs.isEmpty()) {
|
||||||
|
|
||||||
|
bLayer.setRegularizationBias(regularizationBias);
|
||||||
|
} else {
|
||||||
|
|
||||||
|
boolean hasL1 = false;
|
||||||
|
boolean hasL2 = false;
|
||||||
|
final List<Regularization> regContext = regularizationBias;
|
||||||
|
for (final Regularization reg : bLayerRegs) {
|
||||||
|
|
||||||
|
if (reg instanceof L1Regularization) {
|
||||||
|
|
||||||
|
hasL1 = true;
|
||||||
|
} else if (reg instanceof L2Regularization) {
|
||||||
|
|
||||||
|
hasL2 = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (final Regularization reg : regContext) {
|
||||||
|
|
||||||
|
if (reg instanceof L1Regularization) {
|
||||||
|
|
||||||
|
if (!hasL1)
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
} else if (reg instanceof L2Regularization) {
|
||||||
|
|
||||||
|
if (!hasL2)
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
} else
|
||||||
|
bLayerRegs.add(reg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bLayer.getIDropout() == null) {
|
||||||
|
|
||||||
|
bLayer.setIDropout(iDropout);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue