Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-14 13:24:19 +02:00
parent 0f21ed9ec5
commit 1f2e82d3ef
73 changed files with 647 additions and 743 deletions

View File

@ -118,7 +118,7 @@ public class App {
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInit(WeightInit.XAVIER)
.weightInitFn(new WeightInitXavier())
.weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY)
.layersFromArray(genLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))

View File

@ -74,7 +74,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(layer);
assertEquals(act, layer.getActivationFn());
assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn());
assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
assertEquals(new Dropout(dropOut), layer.getIDropout());
assertEquals(updater, layer.getIUpdater());
assertEquals(gradNorm, layer.getGradientNormalization());

View File

@ -99,8 +99,8 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
@ -117,8 +117,8 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);

View File

@ -185,7 +185,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));

View File

@ -157,7 +157,7 @@ public class SameDiffConv extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration clone = globalConfig.clone().build();
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn());
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
}
if (cm == null) {
cm = clone.getConvolutionMode();

View File

@ -119,7 +119,7 @@ public class SameDiffDense extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration clone = globalConfig.clone().build();
if(activation == null){
activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn());
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
}
}

View File

@ -141,9 +141,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
assertEquals(bl0.getWeightInit(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
assertEquals(bl1.getWeightInit(), new WeightInitXavier());
assertEquals(bl1.getWeightInit(), new WeightInitXavier());
ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")

View File

@ -163,14 +163,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
assertEquals(bl0.getWeightInit().getClass(), WeightInitXavier.class);
try {
assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInit()),
JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1))));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
assertEquals(bl3.getWeightInitFn(), new WeightInitXavier());
assertEquals(bl3.getWeightInit(), new WeightInitXavier());
//modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
@ -506,13 +506,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(1e-4), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(1e-4), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
assertEquals(new WeightInitRelu(), l1.getWeightInit());
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
assertEquals(BackpropType.Standard, conf.getBackpropType());
@ -521,13 +521,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(2e-2), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(2e-2), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
assertEquals(new WeightInitRelu(), l1.getWeightInit());
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType());

View File

@ -37,6 +37,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
@ -940,7 +941,9 @@ public class TestUpdaters extends BaseDL4JTest {
List<UpdaterBlock> blocks;
NeuralNetConfiguration conf =
NeuralNetConfiguration.builder().updater(new Adam(0.5)).list()
NeuralNetConfiguration.builder()
.updater(new Adam(0.5))
.weightInit(WeightInit.NORMAL)
.layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12)
.encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build())
.build();

View File

@ -72,7 +72,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -81,7 +81,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -106,7 +106,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout());
@ -118,7 +118,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout());
@ -145,7 +145,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -165,7 +165,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -74,7 +74,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -83,7 +83,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -108,7 +108,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout());
@ -122,7 +122,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout());
@ -151,7 +151,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -171,7 +171,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -75,7 +75,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -84,7 +84,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -109,7 +109,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout());
@ -123,7 +123,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout());
@ -152,7 +152,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -172,7 +172,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -74,7 +74,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationReLU);
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof Nesterovs);
Nesterovs n = (Nesterovs) l0.getIUpdater();
assertEquals(0.9, n.getMomentum(), 1e-6);
@ -87,7 +87,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof Nesterovs);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -113,7 +113,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -130,7 +130,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof RmsProp);
r = (RmsProp) l1.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -162,7 +162,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationTanH);
assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -185,7 +185,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l2.getWeightInitFn());
assertEquals(new WeightInitRelu(), l2.getWeightInit());
assertTrue(l2.getIUpdater() instanceof RmsProp);
r = (RmsProp) l2.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6);

View File

@ -89,21 +89,21 @@ public class RegressionTest100a extends BaseDL4JTest {
GravesLSTM l0 = (GravesLSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater());
GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new RmsProp(0.1), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater());
@ -139,7 +139,7 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.05), l0.getIUpdater());
@ -175,7 +175,7 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

View File

@ -124,21 +124,21 @@ public class RegressionTest100b3 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
@ -174,7 +174,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -210,7 +210,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

View File

@ -142,21 +142,21 @@ public class RegressionTest100b4 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater());
@ -192,7 +192,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -229,7 +229,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp;
@ -260,7 +260,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@ -271,7 +271,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@ -297,7 +297,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@ -318,7 +318,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
@ -327,7 +327,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn());

View File

@ -124,21 +124,21 @@ public class RegressionTest100b6 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater());
@ -174,7 +174,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -210,7 +210,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp;
@ -240,7 +240,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@ -251,7 +251,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@ -277,7 +277,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@ -298,7 +298,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
@ -307,7 +307,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn());

View File

@ -167,7 +167,7 @@ public class KerasInitilizationTest extends BaseDL4JTest {
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion);
DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
assertEquals(dl4jInitializer, layer.getWeightInitFn());
assertEquals(dl4jInitializer, layer.getWeightInit());
}
}

View File

@ -79,7 +79,7 @@ public class KerasPReLUTest extends BaseDL4JTest {
PReLULayer layer = kerasPReLU.getPReLULayer();
assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4});
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(layerName, layer.getLayerName());
}

View File

@ -100,7 +100,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -114,7 +114,7 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -122,7 +122,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -123,7 +123,7 @@ public class KerasConvolution2DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -119,7 +119,7 @@ public class KerasConvolution3DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -123,7 +123,7 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest {
Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -128,7 +128,7 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest {
DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);

View File

@ -130,7 +130,7 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest {
SeparableConvolution2D layer = new KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());

View File

@ -89,7 +89,7 @@ public class KerasDenseTest extends BaseDL4JTest {
DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -38,7 +38,6 @@ import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@ -131,7 +130,7 @@ public class KerasLSTMTest extends BaseDL4JTest {
}
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -101,7 +101,7 @@ public class KerasSimpleRnnTest extends BaseDL4JTest {
(SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying();
assertEquals(ACTIVATION, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn());
assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
INeuralNetworkConfiguration clone();
void init();
/**
@ -35,28 +36,4 @@ public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
* @return
*/
IModel getNet();
}
/**
/**
* Provides a flat list of all embedded layer configurations, this
* can only be called after the layer is initialized or {@link #getLayerConfigurations()} is
* called.
*
* @return unstacked layer configurations
List<ILayerConfiguration> getLayerConfigurations();
/**
* This uncollables any stacked layer configurations within building blocks like
* @link BuildingBlockLayer}
void calculateInnerLayerConfigurations();
/**
* An implementation should provide a method to validate the network
* @return true if no errors found; false otherwise
boolean isValid();
}
**/
}

View File

@ -259,7 +259,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
*/
private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) {
if (layer instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration) layer).getWeightInitFn() == null) {
&& ((BaseLayerConfiguration) layer).getWeightInit() == null) {
String layerName = layer.getLayerName();
try {
@ -291,7 +291,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
((BaseLayerConfiguration) layer).setWeightInitFn(wi);
((BaseLayerConfiguration) layer).setWeightInit(wi);
}
} catch (IOException e) {

View File

@ -35,15 +35,11 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
@ -67,9 +63,9 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.conf.stepfunctions.DefaultStepFunction;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OutputLayerUtil;
@ -319,16 +315,14 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
private boolean validateTbpttConfig = true;
/**
* Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or
* {@link org.nd4j.linalg.learning.config.Nesterovs}<br> Note: values set by this method will be
* applied to all applicable layers in the network, unless a different value is explicitly set on
* a given layer. In other words: values set via this method are used as the default value, and
* can be overridden on a per-layer basis.
* {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless
* a different value is explicitly set on a given layer. In other words: values set via this
* method are used as the default value, and can be overridden on a per-layer basis.
*
* @param updater Updater to use
*/
@Getter
@Setter
private IUpdater updater;
@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
* etc. See {@link GradientNormalization} for details<br> Note: values set by this method will be
@ -357,19 +351,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Setter
private double gradientNormalizationThreshold;
/**
* Activation function / neuron non-linearity<br> Note: values set by this method will be applied
* to all applicable layers in the network, unless a different value is explicitly set on a given
* layer. In other words: values set via this method are used as the default value, and can be
* overridden on a per-layer basis.
*/
@Getter
@Setter
private IActivation activation;
//whether to constrain the gradient to unit norm or not
@Getter
@Setter
private StepFunction stepFunction;
// whether to constrain the gradient to unit norm or not
@Getter @Setter @Builder.Default private StepFunction stepFunction = new DefaultStepFunction();
@Getter
@Setter
@lombok.Builder.Default
@ -400,13 +384,10 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Getter
@lombok.Builder.Default
private List<Regularization> regularizationBias = new ArrayList<>();
@Getter
@Setter
@lombok.Builder.Default
private IUpdater iUpdater = new Sgd();
/**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #setIUpdater(IUpdater)}<br>
* set by {@link #setUpdater(IUpdater)}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis.
@ -420,7 +401,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Getter
@Setter
@lombok.Builder.Default
private IActivation activationFn = new ActivationSigmoid();
private IActivation activation = new ActivationSigmoid();
/**
* Sets the convolution mode for convolutional layers, which impacts padding and output sizes.
@ -698,7 +679,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
ObjectMapper mapper,
JsonNode confs, int layerCount) {
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) {
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
@ -729,7 +710,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
.getWeightInitFunction(dist);
((BaseLayerConfiguration) l).setWeightInitFn(wi);
((BaseLayerConfiguration) l).setWeightInit(wi);
}
}
@ -851,8 +832,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* that do not have an individual setting (nor a default)
*/
for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivationFn());
if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getIUpdater() );
if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivation());
if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getUpdater() );
if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() );
if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise());
@ -1108,29 +1089,27 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
*/
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list
for( Object obj : conf.getInnerConfigurations().stream().skip(1) //don't include self
.collect(Collectors.toList())) {
//if Layer Config, include in list and inherit parameters from this conf
//else if neural net configuration, call self recursively to resolve layer configurations
if (obj instanceof LayerConfiguration)
ret.add((LayerConfiguration) obj);
else if (obj instanceof NeuralNetConfiguration)
ret.addAll(getFlattenedLayerConfigurations(
(NeuralNetConfiguration) obj));
else {
log.error(
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
obj.getClass().getSimpleName());
}
}
/**
LayerConfiguration lc = ((LayerConfiguration) lc).getType().getClazz().cast(obj);
switch(lc.getType()) {
case FC: { //fully connected layer
((FeedForwardLayer) lc).setWeightInitFn(this.getWeightInitFn());
}
if(lc instanceof FeedForwardLayer && ((FeedForwardLayer) lc).getWeightInitFn() == null) {
**/
//When properly initialized, _this_ configuration is set first in the list, however we
//can find cases where this is not true, thus the first configuration is another net or layer configuration
//and should not be skipped. In essence, skip first configuration if that is "this".
int iSkip = 0;
if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
conf.getInnerConfigurations().stream().skip(iSkip)
.forEach(obj -> {
//if Layer Config, include in list and inherit parameters from this conf
//else if neural net configuration, call self recursively to resolve layer configurations
if (obj instanceof LayerConfiguration) {
((LayerConfiguration) obj).setNetConfiguration(conf);
ret.add((LayerConfiguration) obj);
} else if (obj instanceof NeuralNetConfiguration)
ret.addAll(getFlattenedLayerConfigurations(
(NeuralNetConfiguration) obj));
else {
log.error(
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
obj.getClass().getSimpleName());
}
});
return ret;
}
@ -1143,17 +1122,6 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
return getFlattenedLayerConfigurations(this);
}
/**
* Get the configuration of the first layer
* @return layer configuration
*/
/**
public LayerConfiguration getFirstLayer() {
return getFlattenedLayerConfigurations().get(0);
}
**/
/**
* Add a new layer to the first position
* @param layer configuration

View File

@ -23,6 +23,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
@ -30,6 +31,7 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.util.NetworkUtils;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +54,7 @@ import java.util.List;
public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable {
@NonNull
protected IWeightInit weightInitFn;
protected IWeightInit weightInit;
protected double biasInit = 0.0;
protected double gainInit = 0.0;
protected List<Regularization> regularization;
@ -68,7 +70,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
public BaseLayerConfiguration(Builder builder) {
super(builder);
this.layerName = builder.layerName;
this.weightInitFn = builder.weightInitFn;
this.weightInit = builder.weightInit;
this.biasInit = builder.biasInit;
this.gainInit = builder.gainInit;
this.regularization = builder.regularization;
@ -89,7 +91,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
public void resetLayerDefaultConfig() {
//clear the learning related params for all layers in the origConf and set to defaults
this.setIUpdater(null);
this.setWeightInitFn(null);
this.setWeightInit(null);
this.setBiasInit(Double.NaN);
this.setGainInit(Double.NaN);
this.regularization = null;
@ -103,9 +105,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
@Override
public BaseLayerConfiguration clone() {
BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
if (clone.iDropout != null) {
clone.iDropout = clone.iDropout.clone();
}
if(regularization != null){
//Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure
clone.regularization = new ArrayList<>(regularization.size());
@ -170,7 +169,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
*
* @see IWeightInit
*/
protected IWeightInit weightInitFn = null;
protected IWeightInit weightInit = null;
/**
* Bias initialization value, for layers with biases. Defaults to 0
@ -255,7 +254,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
* @see IWeightInit
*/
public T weightInit(IWeightInit weightInit) {
this.setWeightInitFn(weightInit);
this.setWeightInit(weightInit);
return (T) this;
}
@ -270,7 +269,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
"Not supported!, Use weightInit(Distribution distribution) instead!");
}
this.setWeightInitFn(weightInit.getWeightInitFunction());
this.setWeightInit(weightInit.getWeightInitFunction());
return (T) this;
}
@ -508,4 +507,19 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
}
}
/**
* Inherit setting from neural network for those settings, that are not already set or do have
* a layer(type) specific default.
* @param conf the neural net configration to inherit parameters from
*/
@Override
public void runInheritance(@NotNull NeuralNetConfiguration conf) {
super.runInheritance(conf);
if(this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater();
if(this.iUpdater == null ) this.iUpdater = conf.getUpdater();
if(this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
if(this.regularization == null ) this.regularization = conf.getRegularization();
if(this.gradientNormalization == null) this.gradientNormalization = conf.getGradientNormalization();
}
}

View File

@ -172,6 +172,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
setNetConfiguration(conf);
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.runInheritance();
LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut());
@ -404,9 +405,10 @@ public class ConvolutionLayer extends FeedForwardLayer {
/**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
* Default is {@link ConvolutionMode}.Truncate.
*
*/
protected ConvolutionMode convolutionMode;
protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
/**
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions,

View File

@ -62,19 +62,18 @@ public class DenseLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.runInheritance();
org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier());
if(getWeightInit() == null) setWeightInit(new WeightInitXavier());
ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
ret.setParamTable(paramTable);
ret.setLayerConfiguration(lconf);
return ret;
}

View File

@ -217,14 +217,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
return this;
}
@Override
public void setWeightInitFn(IWeightInit weightInit){
if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]);
nOut(shape[1]);
}
this.weightInitFn = weightInit;
this.weightInit = weightInit;
}
/**

View File

@ -66,28 +66,29 @@ import org.nd4j.linalg.learning.regularization.Regularization;
@Slf4j
public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration
protected String layerName = "noname";
protected String layerName;
@Getter
protected List<String> variables = new ArrayList<>();
public void addVariable(String s) {variables.add(s);}
protected IDropout iDropout;
protected List<LayerConstraint> constraints;
protected IWeightNoise weightNoise;
private IDropout iDropout;
/**
* The type of the layer, basically defines the base class and its properties
*/
@Getter @Setter @NonNull
private LayerType type = LayerType.UNKNOWN;
@Getter @Setter
private NeuralNetConfiguration netConfiguration;
@Getter @Setter
private IActivation activationFn;
public LayerConfiguration(Builder builder) {
this.layerName = builder.layerName;
this.iDropout = builder.iDropout;
}
public void addVariable(String s) {variables.add(s);}
public String toJson() {
throw new RuntimeException("toJson is not implemented for LayerConfiguration");
}
@ -151,6 +152,7 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
public LayerConfiguration getLayer() {
return this;
}
@Override
public LayerConfiguration clone() {
try {
@ -218,7 +220,6 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
*/
public abstract void setNIn(InputType inputType, boolean override);
/**
* For the given type of input to this layer, what preprocessor (if any) is required?<br>
* Returns null if no preprocessor is required, otherwise returns an appropriate {@link
@ -263,11 +264,11 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
"Not supported: all layers with parameters should override this method");
}
public IUpdater getIUpdater() {
throw new UnsupportedOperationException(
"Not supported: all layers with parameters should override this method");
}
public void setIUpdater(IUpdater iUpdater) {
log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName());
}
@ -285,15 +286,33 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
this.variables.clear();
}
@Getter @Setter
private IActivation activationFn;
/**
* Inherit setting from neural network for those settings, that are not already set or do have
* a layer(type) specific default. This implementation does not require the neural network configuration to be
* the same as the one returned from this layers {@link #getNetConfiguration()}.
*
* @param conf a neural net configration to inherit parameters from
*
*/
public void runInheritance(@NonNull NeuralNetConfiguration conf) {
if(this.activationFn == null ) this.activationFn = conf.getActivation();
if(this.iDropout == null ) this.iDropout = conf.getIdropOut();
if(this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
}
/** Runs {@link #runInheritance(NeuralNetConfiguration)} using the layers configurations embedded neural net
* configuration (the one returned from {@link #getNetConfiguration()}.
*/
public void runInheritance() {
runInheritance(getNetConfiguration());
}
@SuppressWarnings("unchecked")
@Getter
@Setter
public abstract static class Builder<T extends Builder<T>> {
protected String layerName = "noname";
protected String layerName;
protected List<LayerConstraint> allParamConstraints;

View File

@ -215,7 +215,7 @@ public class LocallyConnected1D extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration global_conf = globalConfig.build();
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivationFn());
activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation());
}
if (cm == null) {
cm = global_conf.getConvolutionMode();

View File

@ -232,7 +232,7 @@ public class LocallyConnected2D extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration gconf = globalConfig.build();
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(gconf.getActivationFn());
activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation());
}
if (cm == null) {
cm = gconf.getConvolutionMode();

View File

@ -117,7 +117,7 @@ public class PReLULayer extends BaseLayerConfiguration {
public Builder(){
//Default to 0s, and don't inherit global default
this.weightInitFn = new WeightInitConstant(0);
this.weightInit = new WeightInitConstant(0);
}
/**

View File

@ -152,7 +152,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
@Override
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivationFn());
activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation());
}
}

View File

@ -196,7 +196,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
regularizationBias = bConf.getRegularizationBias();
}
if (updater == null) {
updater = bConf.getIUpdater();
updater = bConf.getUpdater();
}
if (biasUpdater == null) {
biasUpdater = bConf.getBiasUpdater();

View File

@ -156,7 +156,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
regularizationBias = b_conf.getRegularizationBias();
}
if (updater == null) {
updater = b_conf.getIUpdater();
updater = b_conf.getUpdater();
}
if (biasUpdater == null) {
biasUpdater = b_conf.getBiasUpdater();

View File

@ -72,6 +72,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret =
new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType);
lconf.runInheritance();
ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex);

View File

@ -98,7 +98,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){
for(LayerConfiguration l : layers){
if(l instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)l).getWeightInitFn() == null){
&& ((BaseLayerConfiguration)l).getWeightInit() == null){
return true;
}
}
@ -254,7 +254,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
}
IWeightInit iwi = w.getWeightInitFunction(d);
baseLayerConfiguration.setWeightInitFn(iwi);
baseLayerConfiguration.setWeightInit(iwi);
} catch (Throwable t){
log.warn("Failed to infer weight initialization from legacy JSON format",t);
}

View File

@ -129,7 +129,7 @@ public class ComputationGraphConfigurationDeserializer
}
if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null){
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null){
handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
}
@ -160,7 +160,7 @@ public class ComputationGraphConfigurationDeserializer
layerIdx++;
} else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){
if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null) {
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null) {
//Post JSON format change for subclasses, but before WeightInit was made a class
confNode = (ObjectNode) next.get("layerConf");
next = confNode.get("layer");

View File

@ -141,7 +141,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
}
if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration) layers[i]).getWeightInitFn() == null) {
&& ((BaseLayerConfiguration) layers[i]).getWeightInit() == null) {
handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on);
}

View File

@ -88,14 +88,19 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
cacheMode = layerConfiguration.getNetConfiguration().getCacheMode();
}
this.dataType = dataType;
if (layerConfiguration.getNetConfiguration() == null) {
throw new RuntimeException("You cannot create a layer from a layer configuration, that is not part of any neural network configuration.");
}
this.net = layerConfiguration.getNetConfiguration().getNet();
}
public void addTrainingListeners(TrainingListener... listeners) {
if(listeners != null)
trainingListeners.addAll(List.of(listeners));
}
public void addTrainingListeners(Collection<TrainingListener> listeners) {
if(listeners != null)
trainingListeners.addAll(listeners);
}

View File

@ -77,7 +77,6 @@ public abstract class BaseLayer<LayerConfT extends BaseLayerConfiguration>
* INDArray params;
*/
public BaseLayer(LayerConfiguration conf, DataType dataType) {
super(conf, dataType);
}

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.layers.ocnn;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
@ -154,7 +153,7 @@ public class OCNNParamInitializer extends DefaultParamInitializer {
boolean initializeParameters) {
org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration;
IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn();
IWeightInit weightInit = ocnnOutputLayer.getWeightInit();
if (initializeParameters) {
INDArray ret = weightInit.init(weightParamView.size(0), //Fan in
weightParamView.size(1), //Fan out

View File

@ -92,7 +92,7 @@ public class VariationalAutoencoder implements Layer {
protected int epochCount;
@Getter @Setter @NonNull
private LayerConfiguration layerConfiguration;
private @Getter @Setter Collection<TrainingListener> trainingListeners;
private @Getter @Setter Collection<TrainingListener> trainingListeners = new HashSet<>();
public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) {
this.layerConfiguration = layerConfiguration;
@ -113,6 +113,27 @@ public class VariationalAutoencoder implements Layer {
.getNumSamples();
}
/**
* Replace the TrainingListeners for this model
*
* @param listeners new listeners
*/
@Override
public void addTrainingListeners(TrainingListener... listeners) {
if(listeners != null)
trainingListeners.addAll(List.of(listeners));
}
/**
*
* @param listeners
*/
@Override
public void addTrainingListeners(Collection<TrainingListener> listeners) {
if(listeners != null)
trainingListeners.addAll(listeners);
}
/**
* Get a reference to the network this layer is part of.
*
@ -1214,24 +1235,6 @@ public class VariationalAutoencoder implements Layer {
//No-op for individual layers
}
/**
* Replace the TrainingListeners for this model
*
* @param listeners new listeners
*/
@Override
public void addTrainingListeners(TrainingListener... listeners) {
trainingListeners.addAll(List.of(listeners));
}
/**
*
* @param listeners
*/
@Override
public void addTrainingListeners(Collection<TrainingListener> listeners) {
trainingListeners.addAll(listeners);
}
@AllArgsConstructor
@Data

View File

@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -131,7 +130,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer {
val weightsShape = new long[]{outputDepth, inputDepth, kernel[0], kernel[1], kernel[2]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();

View File

@ -180,7 +180,7 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
val weightsShape = new long[] {outputDepth, inputDepth, kernel[0], kernel[1]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView);
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else {
int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights(

View File

@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -130,7 +129,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer
//libnd4j: [kD, kH, kW, oC, iC]
val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView);
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else {
int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights(

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,7 +60,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer {
val weightsShape = new long[] {inputDepth, outputDepth, kernel[0], kernel[1]};
INDArray weights = layerConf.getWeightInitFn().init(
INDArray weights = layerConf.getWeightInit().init(
fanIn, fanOut, weightsShape, 'c', weightView);
return weights;

View File

@ -196,13 +196,13 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;
if (initializeParameters) {
if( layerConf.getWeightInitFn() == null) {
if( layerConf.getWeightInit() == null) {
// set a default and set warning
layerConf.setWeightInitFn(new WeightInitXavier());
layerConf.setWeightInit(new WeightInitXavier());
log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(),
conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName());
}
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInitFn(),
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
weightParamView, true);
} else {
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false);

View File

@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -193,7 +191,7 @@ public class DepthwiseConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {kernel[0], kernel[1], inputDepth, depthMultiplier};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -159,14 +157,14 @@ public class GravesBidirectionalLSTMParamInitializer extends AbstractParamInitia
val inputWShape = new long[]{nLast, 4 * nL};
val recurrentWShape = new long[]{nL, 4 * nL + 3};
params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwF));
params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape,
params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwF));
params.put(BIAS_KEY_FORWARDS, bF);
params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwR));
params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape,
params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwR));
params.put(BIAS_KEY_BACKWARDS, bR);
} else {

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -124,10 +122,10 @@ public class GravesLSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent();
} else {
rwInit = layerConf.getWeightInitFn();
rwInit = layerConf.getWeightInit();
}
params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));

View File

@ -27,7 +27,6 @@ import java.util.List;
import java.util.Map;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
@ -132,10 +131,10 @@ public class LSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent();
} else {
rwInit = layerConf.getWeightInitFn();
rwInit = layerConf.getWeightInit();
}
params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));
biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)},

View File

@ -133,7 +133,7 @@ public class PReLUParamInitializer extends AbstractParamInitializer {
PReLULayer layerConf = (PReLULayer) conf;
if (initializeParameters) {
return layerConf.getWeightInitFn().init(layerConf.getNIn(), layerConf.getNOut(),
return layerConf.getWeightInit().init(layerConf.getNIn(), layerConf.getNOut(),
weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
} else {
return WeightInitUtil.reshapeWeights(weightShape, weightParamView);

View File

@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -220,7 +218,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {depthMultiplier, inputDepth, kernel[0], kernel[1]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();
@ -249,7 +247,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {outputDepth, depthMultiplier * inputDepth, 1, 1};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
return WeightInitUtil.reshapeWeights(

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.IWeightInit;
@ -102,14 +100,14 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {
if (initializeParams) {
m = getSubsets(paramsView, nIn, nOut, false, hasLayerNorm(c));
INDArray w = c.getWeightInitFn().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY));
INDArray w = c.getWeightInit().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY));
m.put(WEIGHT_KEY, w);
IWeightInit rwInit;
if (c.getWeightInitFnRecurrent() != null) {
rwInit = c.getWeightInitFnRecurrent();
} else {
rwInit = c.getWeightInitFn();
rwInit = c.getWeightInit();
}
INDArray rw = rwInit.init(nOut, nOut, new long[]{nOut, nOut}, 'f', m.get(RECURRENT_WEIGHT_KEY));

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit;
@ -200,7 +199,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
int[] encoderLayerSizes = layer.getEncoderLayerSizes();
int[] decoderLayerSizes = layer.getDecoderLayerSizes();
IWeightInit weightInit = layer.getWeightInitFn();
IWeightInit weightInit = layer.getWeightInit();
int soFar = 0;
for (int i = 0; i < encoderLayerSizes.length; i++) {

View File

@ -164,7 +164,7 @@ public class FineTuneConfiguration {
bl.setActivationFn(activationFn);
}
if (weightInitFn != null) {
bl.setWeightInitFn(weightInitFn);
bl.setWeightInit(weightInitFn);
}
if (biasInit != null) {
bl.setBiasInit(biasInit);
@ -264,10 +264,10 @@ public class FineTuneConfiguration {
NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = NeuralNetConfiguration.builder();
if (activationFn != null) {
confBuilder.activationFn(activationFn);
confBuilder.activation(activationFn);
}
if (weightInitFn != null) {
confBuilder.weightInitFn(weightInitFn);
confBuilder.weightInit(weightInitFn);
}
if (biasInit != null) {
confBuilder.biasInit(biasInit);

View File

@ -462,7 +462,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(init);
layerImplF.setWeightInit(init);
layerImplF.setNIn(nIn);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);
@ -480,7 +480,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);
@ -492,7 +492,7 @@ public class TransferLearning {
layerImpl = layerConf; //modify in place
if(layerImpl instanceof FeedForwardLayer) {
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext);
layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut);
numParams = layerImpl.initializer().numParams(layerConf);
if (numParams > 0) {
@ -738,7 +738,7 @@ public class TransferLearning {
layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setWeightInit(scheme);
layerImplF.setNIn(nIn);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
@ -767,7 +767,7 @@ public class TransferLearning {
LayerConfiguration layerImpl = layerConf.clone();
layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
@ -806,7 +806,7 @@ public class TransferLearning {
continue;
layerImpl = layerConf.clone();
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext);
layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut);
nInFromNewConfig.put(fanoutVertexName, nOut);

View File

@ -207,10 +207,11 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
*/
public void setStateViewArray(INDArray viewArray) {
if(this.updaterStateViewArray == null){
if(viewArray == null)
if(viewArray == null || viewArray.length()==0)
return; //No op - for example, SGD and NoOp updater - i.e., no stored state
else {
throw new IllegalStateException("Attempting to set updater state view array with null value");
//this.updaterStateViewArray.set
// throw new IllegalStateException("Attempting to set updater state view array with null value");
}
}
if (this.updaterStateViewArray.length() != viewArray.length())
@ -296,7 +297,7 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
//PRE apply (gradient clipping, etc): done on a per-layer basis
for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
String layerName = entry.getKey();
ITrainableLayer layer = layersByName.get(layerName);
ITrainableLayer layer = layersByName.get(layerName); //Todo Layers may have the same name!?
preApply(layer, layerGradients.get(layerName), iteration);
}

View File

@ -29,7 +29,6 @@ import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@ -39,7 +38,6 @@ import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.learning.config.Adam;
@ -85,8 +83,8 @@ class dnnTest {
.updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
.weightInitFn(new WeightInitXavier())
.activationFn(new ActivationSigmoid())
.weightInit(new WeightInitXavier())
.activation(new ActivationSigmoid())
// .inputType(InputType.convolutional(28, 28, 1))
.layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())

View File

@ -1182,7 +1182,7 @@ public class TrainModule implements UIModule {
String.valueOf(nParams)});
if (nParams > 0) {
try {
String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn());
String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInit());
layerInfoRows.add(new String[]{
i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str});
} catch (JsonProcessingException e) {

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;

View File

@ -176,7 +176,7 @@ public class ResNet50 extends ZooModel {
.activation(Activation.IDENTITY)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(updater)
.weightInitFn(weightInit)
.weightInit(weightInit)
.l1(1e-7)
.l2(5e-5)
.miniBatch(true)