Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-14 13:24:19 +02:00
parent 0f21ed9ec5
commit 1f2e82d3ef
73 changed files with 647 additions and 743 deletions

View File

@ -118,7 +118,7 @@ public class App {
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(GRADIENT_THRESHOLD) .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
//.weightInit(WeightInit.XAVIER) //.weightInit(WeightInit.XAVIER)
.weightInitFn(new WeightInitXavier()) .weightInit(WeightInit.XAVIER)
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.layersFromArray(genLayers()) .layersFromArray(genLayers())
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))

View File

@ -74,7 +74,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(layer); checkSerialization(layer);
assertEquals(act, layer.getActivationFn()); assertEquals(act, layer.getActivationFn());
assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn()); assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
assertEquals(new Dropout(dropOut), layer.getIDropout()); assertEquals(new Dropout(dropOut), layer.getIDropout());
assertEquals(updater, layer.getIUpdater()); assertEquals(updater, layer.getIUpdater());
assertEquals(gradNorm, layer.getGradientNormalization()); assertEquals(gradNorm, layer.getGradientNormalization());

View File

@ -99,8 +99,8 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
@ -117,8 +117,8 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());
assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);

View File

@ -185,7 +185,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization())); assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));

View File

@ -157,7 +157,7 @@ public class SameDiffConv extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration clone = globalConfig.clone().build(); NeuralNetConfiguration clone = globalConfig.clone().build();
if (activation == null) { if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
} }
if (cm == null) { if (cm == null) {
cm = clone.getConvolutionMode(); cm = clone.getConvolutionMode();

View File

@ -119,7 +119,7 @@ public class SameDiffDense extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration clone = globalConfig.clone().build(); NeuralNetConfiguration clone = globalConfig.clone().build();
if(activation == null){ if(activation == null){
activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
} }
} }

View File

@ -141,9 +141,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration()); BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); assertEquals(bl0.getWeightInit(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInit(), new WeightInitXavier());
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInit(), new WeightInitXavier());
ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")

View File

@ -163,14 +163,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer()); BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); assertEquals(bl0.getWeightInit().getClass(), WeightInitXavier.class);
try { try {
assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInit()),
JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1)))); JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1))));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
assertEquals(bl3.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl3.getWeightInit(), new WeightInitXavier());
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
@ -506,13 +506,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration(); BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(new Adam(1e-4), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(new Adam(1e-4), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(new WeightInitRelu(), l1.getWeightInit());
assertEquals(0.2, TestUtils.getL2(l1), 1e-6); assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
assertEquals(BackpropType.Standard, conf.getBackpropType()); assertEquals(BackpropType.Standard, conf.getBackpropType());
@ -521,13 +521,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration(); l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(new Adam(2e-2), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration(); l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(new Adam(2e-2), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(new WeightInitRelu(), l1.getWeightInit());
assertEquals(0.2, TestUtils.getL2(l1), 1e-6); assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType());

View File

@ -37,6 +37,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -940,7 +941,9 @@ public class TestUpdaters extends BaseDL4JTest {
List<UpdaterBlock> blocks; List<UpdaterBlock> blocks;
NeuralNetConfiguration conf = NeuralNetConfiguration conf =
NeuralNetConfiguration.builder().updater(new Adam(0.5)).list() NeuralNetConfiguration.builder()
.updater(new Adam(0.5))
.weightInit(WeightInit.NORMAL)
.layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12) .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12)
.encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build()) .encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build())
.build(); .build();

View File

@ -72,7 +72,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString()); assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -81,7 +81,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT); assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -106,7 +106,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout()); assertEquals(new Dropout(0.6), l0.getIDropout());
@ -118,7 +118,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE); assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout()); assertEquals(new Dropout(0.6), l1.getIDropout());
@ -145,7 +145,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString()); assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -165,7 +165,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -74,7 +74,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString()); assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -83,7 +83,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT); assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -108,7 +108,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout()); assertEquals(new Dropout(0.6), l0.getIDropout());
@ -122,7 +122,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE); assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout()); assertEquals(new Dropout(0.6), l1.getIDropout());
@ -151,7 +151,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString()); assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -171,7 +171,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -75,7 +75,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals("relu", l0.getActivationFn().toString()); assertEquals("relu", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
@ -84,7 +84,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT); assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -109,7 +109,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getIDropout()); assertEquals(new Dropout(0.6), l0.getIDropout());
@ -123,7 +123,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE); assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getIDropout()); assertEquals(new Dropout(0.6), l1.getIDropout());
@ -152,7 +152,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals("tanh", l0.getActivationFn().toString()); assertEquals("tanh", l0.getActivationFn().toString());
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@ -172,7 +172,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

View File

@ -74,7 +74,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationReLU); assertTrue(l0.getActivationFn() instanceof ActivationReLU);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof Nesterovs); assertTrue(l0.getIUpdater() instanceof Nesterovs);
Nesterovs n = (Nesterovs) l0.getIUpdater(); Nesterovs n = (Nesterovs) l0.getIUpdater();
assertEquals(0.9, n.getMomentum(), 1e-6); assertEquals(0.9, n.getMomentum(), 1e-6);
@ -87,7 +87,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMCXENT); assertTrue(l1.getLossFn() instanceof LossMCXENT);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof Nesterovs); assertTrue(l1.getIUpdater() instanceof Nesterovs);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@ -113,7 +113,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp); assertTrue(l0.getIUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater(); RmsProp r = (RmsProp) l0.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -130,7 +130,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l1.getLossFn() instanceof LossMSE); assertTrue(l1.getLossFn() instanceof LossMSE);
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof RmsProp); assertTrue(l1.getIUpdater() instanceof RmsProp);
r = (RmsProp) l1.getIUpdater(); r = (RmsProp) l1.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -162,7 +162,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l0.getActivationFn() instanceof ActivationTanH); assertTrue(l0.getActivationFn() instanceof ActivationTanH);
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp); assertTrue(l0.getIUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater(); RmsProp r = (RmsProp) l0.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
@ -185,7 +185,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l2.getWeightInitFn()); assertEquals(new WeightInitRelu(), l2.getWeightInit());
assertTrue(l2.getIUpdater() instanceof RmsProp); assertTrue(l2.getIUpdater() instanceof RmsProp);
r = (RmsProp) l2.getIUpdater(); r = (RmsProp) l2.getIUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);

View File

@ -89,21 +89,21 @@ public class RegressionTest100a extends BaseDL4JTest {
GravesLSTM l0 = (GravesLSTM) net.getLayer(0).getLayerConfiguration(); GravesLSTM l0 = (GravesLSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater()); assertEquals(new RmsProp(0.1), l0.getIUpdater());
GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration(); GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new RmsProp(0.1), l1.getIUpdater()); assertEquals(new RmsProp(0.1), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater()); assertEquals(new RmsProp(0.1), l0.getIUpdater());
@ -139,7 +139,7 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(32, l0.getNOut()); assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.05), l0.getIUpdater()); assertEquals(new Adam(0.05), l0.getIUpdater());
@ -175,7 +175,7 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

View File

@ -124,21 +124,21 @@ public class RegressionTest100b3 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
@ -174,7 +174,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(32, l0.getNOut()); assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -210,7 +210,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

View File

@ -142,21 +142,21 @@ public class RegressionTest100b4 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getIUpdater());
@ -192,7 +192,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(32, l0.getNOut()); assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -229,7 +229,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize()); assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp; INDArray outExp;
@ -260,7 +260,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@ -271,7 +271,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut()); assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@ -297,7 +297,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration(); DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut()); assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater()); assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@ -318,7 +318,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
assertEquals(4, l8.getNOut()); assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater()); assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
@ -327,7 +327,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertArrayEquals(new int[]{0, 0}, l8.getPadding()); assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn()); assertEquals(new LossMAE(), l9.getLossFn());

View File

@ -124,21 +124,21 @@ public class RegressionTest100b6 extends BaseDL4JTest {
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getIUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getIUpdater());
@ -174,7 +174,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(32, l0.getNOut()); assertEquals(32, l0.getNOut());
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getIUpdater());
@ -210,7 +210,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(new ActivationIdentity(), cl.getActivationFn());
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); assertEquals(new WeightInitXavier(), cl.getWeightInit());
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize()); assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
INDArray outExp; INDArray outExp;
@ -240,7 +240,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(new ActivationReLU(), l0.getActivationFn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@ -251,7 +251,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(new ActivationReLU(), l1.getActivationFn());
assertEquals(8, l1.getNOut()); assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@ -277,7 +277,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration(); DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(new ActivationReLU(), l5.getActivationFn());
assertEquals(16, l5.getNOut()); assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater()); assertEquals(new Adam(0.005), l5.getIUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@ -298,7 +298,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
assertEquals(4, l8.getNOut()); assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater()); assertEquals(new Adam(0.005), l8.getIUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
@ -307,7 +307,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertArrayEquals(new int[]{0, 0}, l8.getPadding()); assertArrayEquals(new int[]{0, 0}, l8.getPadding());
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new Adam(0.005), l9.getIUpdater());
assertEquals(new LossMAE(), l9.getLossFn()); assertEquals(new LossMAE(), l9.getLossFn());

View File

@ -167,7 +167,7 @@ public class KerasInitilizationTest extends BaseDL4JTest {
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion);
DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer(); DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
assertEquals(dl4jInitializer, layer.getWeightInitFn()); assertEquals(dl4jInitializer, layer.getWeightInit());
} }
} }

View File

@ -79,7 +79,7 @@ public class KerasPReLUTest extends BaseDL4JTest {
PReLULayer layer = kerasPReLU.getPReLULayer(); PReLULayer layer = kerasPReLU.getPReLULayer();
assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4}); assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4});
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(layerName, layer.getLayerName()); assertEquals(layerName, layer.getLayerName());
} }

View File

@ -100,7 +100,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D(); Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -114,7 +114,7 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D(); ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -122,7 +122,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer(); Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -123,7 +123,7 @@ public class KerasConvolution2DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer(); ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -119,7 +119,7 @@ public class KerasConvolution3DTest extends BaseDL4JTest {
ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer(); ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -123,7 +123,7 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest {
Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer(); Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -128,7 +128,7 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest {
DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer(); DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier()); assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);

View File

@ -130,7 +130,7 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest {
SeparableConvolution2D layer = new KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer(); SeparableConvolution2D layer = new KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier()); assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());

View File

@ -89,7 +89,7 @@ public class KerasDenseTest extends BaseDL4JTest {
DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer(); DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -38,7 +38,6 @@ import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -131,7 +130,7 @@ public class KerasLSTMTest extends BaseDL4JTest {
} }
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -101,7 +101,7 @@ public class KerasSimpleRnnTest extends BaseDL4JTest {
(SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying(); (SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying();
assertEquals(ACTIVATION, layer.getActivationFn().toString()); assertEquals(ACTIVATION, layer.getActivationFn().toString());
assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(LAYER_NAME, layer.getLayerName());
assertEquals(INIT_DL4J, layer.getWeightInitFn()); assertEquals(INIT_DL4J, layer.getWeightInit());
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
public interface INeuralNetworkConfiguration extends Serializable, Cloneable { public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
INeuralNetworkConfiguration clone(); INeuralNetworkConfiguration clone();
void init(); void init();
/** /**
@ -35,28 +36,4 @@ public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
* @return * @return
*/ */
IModel getNet(); IModel getNet();
} }
/**
/**
* Provides a flat list of all embedded layer configurations, this
* can only be called after the layer is initialized or {@link #getLayerConfigurations()} is
* called.
*
* @return unstacked layer configurations
List<ILayerConfiguration> getLayerConfigurations();
/**
* This uncollables any stacked layer configurations within building blocks like
* @link BuildingBlockLayer}
void calculateInnerLayerConfigurations();
/**
* An implementation should provide a method to validate the network
* @return true if no errors found; false otherwise
boolean isValid();
}
**/

View File

@ -259,7 +259,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
*/ */
private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) { private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) {
if (layer instanceof BaseLayerConfiguration if (layer instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration) layer).getWeightInitFn() == null) { && ((BaseLayerConfiguration) layer).getWeightInit() == null) {
String layerName = layer.getLayerName(); String layerName = layer.getLayerName();
try { try {
@ -291,7 +291,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
if (weightInit != null) { if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist); final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
((BaseLayerConfiguration) layer).setWeightInitFn(wi); ((BaseLayerConfiguration) layer).setWeightInit(wi);
} }
} catch (IOException e) { } catch (IOException e) {

View File

@ -35,15 +35,11 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.*;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized; import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
@ -67,9 +63,9 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.conf.stepfunctions.DefaultStepFunction;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OutputLayerUtil; import org.deeplearning4j.util.OutputLayerUtil;
@ -319,16 +315,14 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
private boolean validateTbpttConfig = true; private boolean validateTbpttConfig = true;
/** /**
* Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or
* {@link org.nd4j.linalg.learning.config.Nesterovs}<br> Note: values set by this method will be * {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
* applied to all applicable layers in the network, unless a different value is explicitly set on * Note: values set by this method will be applied to all applicable layers in the network, unless
* a given layer. In other words: values set via this method are used as the default value, and * a different value is explicitly set on a given layer. In other words: values set via this
* can be overridden on a per-layer basis. * method are used as the default value, and can be overridden on a per-layer basis.
* *
* @param updater Updater to use * @param updater Updater to use
*/ */
@Getter @Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
@Setter
private IUpdater updater;
/** /**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
* etc. See {@link GradientNormalization} for details<br> Note: values set by this method will be * etc. See {@link GradientNormalization} for details<br> Note: values set by this method will be
@ -357,19 +351,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Setter @Setter
private double gradientNormalizationThreshold; private double gradientNormalizationThreshold;
/** // whether to constrain the gradient to unit norm or not
* Activation function / neuron non-linearity<br> Note: values set by this method will be applied @Getter @Setter @Builder.Default private StepFunction stepFunction = new DefaultStepFunction();
* to all applicable layers in the network, unless a different value is explicitly set on a given
* layer. In other words: values set via this method are used as the default value, and can be
* overridden on a per-layer basis.
*/
@Getter
@Setter
private IActivation activation;
//whether to constrain the gradient to unit norm or not
@Getter
@Setter
private StepFunction stepFunction;
@Getter @Getter
@Setter @Setter
@lombok.Builder.Default @lombok.Builder.Default
@ -400,13 +384,10 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Getter @Getter
@lombok.Builder.Default @lombok.Builder.Default
private List<Regularization> regularizationBias = new ArrayList<>(); private List<Regularization> regularizationBias = new ArrayList<>();
@Getter
@Setter
@lombok.Builder.Default
private IUpdater iUpdater = new Sgd();
/** /**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #setIUpdater(IUpdater)}<br> * set by {@link #setUpdater(IUpdater)}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless a different * Note: values set by this method will be applied to all applicable layers in the network, unless a different
* value is explicitly set on a given layer. In other words: values set via this method are used as the default * value is explicitly set on a given layer. In other words: values set via this method are used as the default
* value, and can be overridden on a per-layer basis. * value, and can be overridden on a per-layer basis.
@ -420,7 +401,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
@Getter @Getter
@Setter @Setter
@lombok.Builder.Default @lombok.Builder.Default
private IActivation activationFn = new ActivationSigmoid(); private IActivation activation = new ActivationSigmoid();
/** /**
* Sets the convolution mode for convolutional layers, which impacts padding and output sizes. * Sets the convolution mode for convolutional layers, which impacts padding and output sizes.
@ -698,7 +679,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
ObjectMapper mapper, ObjectMapper mapper,
JsonNode confs, int layerCount) { JsonNode confs, int layerCount) {
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) { if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) {
try { try {
JsonNode jsonNode = mapper.readTree(json); JsonNode jsonNode = mapper.readTree(json);
if (confs == null) { if (confs == null) {
@ -729,7 +710,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
if (weightInit != null) { if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
.getWeightInitFunction(dist); .getWeightInitFunction(dist);
((BaseLayerConfiguration) l).setWeightInitFn(wi); ((BaseLayerConfiguration) l).setWeightInit(wi);
} }
} }
@ -851,8 +832,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* that do not have an individual setting (nor a default) * that do not have an individual setting (nor a default)
*/ */
for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) { for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivationFn()); if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivation());
if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getIUpdater() ); if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getUpdater() );
if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() ); if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() );
if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise()); if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise());
@ -1108,13 +1089,19 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
*/ */
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) { public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list
for( Object obj : conf.getInnerConfigurations().stream().skip(1) //don't include self //When properly initialized, _this_ configuration is set first in the list, however we
.collect(Collectors.toList())) { //can find cases where this is not true, thus the first configuration is another net or layer configuration
//and should not be skipped. In essence, skip first configuration if that is "this".
int iSkip = 0;
if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
conf.getInnerConfigurations().stream().skip(iSkip)
.forEach(obj -> {
//if Layer Config, include in list and inherit parameters from this conf //if Layer Config, include in list and inherit parameters from this conf
//else if neural net configuration, call self recursively to resolve layer configurations //else if neural net configuration, call self recursively to resolve layer configurations
if (obj instanceof LayerConfiguration) if (obj instanceof LayerConfiguration) {
((LayerConfiguration) obj).setNetConfiguration(conf);
ret.add((LayerConfiguration) obj); ret.add((LayerConfiguration) obj);
else if (obj instanceof NeuralNetConfiguration) } else if (obj instanceof NeuralNetConfiguration)
ret.addAll(getFlattenedLayerConfigurations( ret.addAll(getFlattenedLayerConfigurations(
(NeuralNetConfiguration) obj)); (NeuralNetConfiguration) obj));
else { else {
@ -1122,15 +1109,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.", "The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
obj.getClass().getSimpleName()); obj.getClass().getSimpleName());
} }
} });
/**
LayerConfiguration lc = ((LayerConfiguration) lc).getType().getClazz().cast(obj);
switch(lc.getType()) {
case FC: { //fully connected layer
((FeedForwardLayer) lc).setWeightInitFn(this.getWeightInitFn());
}
if(lc instanceof FeedForwardLayer && ((FeedForwardLayer) lc).getWeightInitFn() == null) {
**/
return ret; return ret;
} }
@ -1143,17 +1122,6 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
return getFlattenedLayerConfigurations(this); return getFlattenedLayerConfigurations(this);
} }
/**
* Get the configuration of the first layer
* @return layer configuration
*/
/**
public LayerConfiguration getFirstLayer() {
return getFlattenedLayerConfigurations().get(0);
}
**/
/** /**
* Add a new layer to the first position * Add a new layer to the first position
* @param layer configuration * @param layer configuration

View File

@ -23,6 +23,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.*;
import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
@ -30,6 +31,7 @@ import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.util.NetworkUtils; import org.deeplearning4j.util.NetworkUtils;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +54,7 @@ import java.util.List;
public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable { public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable {
@NonNull @NonNull
protected IWeightInit weightInitFn; protected IWeightInit weightInit;
protected double biasInit = 0.0; protected double biasInit = 0.0;
protected double gainInit = 0.0; protected double gainInit = 0.0;
protected List<Regularization> regularization; protected List<Regularization> regularization;
@ -68,7 +70,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
public BaseLayerConfiguration(Builder builder) { public BaseLayerConfiguration(Builder builder) {
super(builder); super(builder);
this.layerName = builder.layerName; this.layerName = builder.layerName;
this.weightInitFn = builder.weightInitFn; this.weightInit = builder.weightInit;
this.biasInit = builder.biasInit; this.biasInit = builder.biasInit;
this.gainInit = builder.gainInit; this.gainInit = builder.gainInit;
this.regularization = builder.regularization; this.regularization = builder.regularization;
@ -89,7 +91,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
public void resetLayerDefaultConfig() { public void resetLayerDefaultConfig() {
//clear the learning related params for all layers in the origConf and set to defaults //clear the learning related params for all layers in the origConf and set to defaults
this.setIUpdater(null); this.setIUpdater(null);
this.setWeightInitFn(null); this.setWeightInit(null);
this.setBiasInit(Double.NaN); this.setBiasInit(Double.NaN);
this.setGainInit(Double.NaN); this.setGainInit(Double.NaN);
this.regularization = null; this.regularization = null;
@ -103,9 +105,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
@Override @Override
public BaseLayerConfiguration clone() { public BaseLayerConfiguration clone() {
BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone(); BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
if (clone.iDropout != null) {
clone.iDropout = clone.iDropout.clone();
}
if(regularization != null){ if(regularization != null){
//Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure //Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure
clone.regularization = new ArrayList<>(regularization.size()); clone.regularization = new ArrayList<>(regularization.size());
@ -170,7 +169,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
* *
* @see IWeightInit * @see IWeightInit
*/ */
protected IWeightInit weightInitFn = null; protected IWeightInit weightInit = null;
/** /**
* Bias initialization value, for layers with biases. Defaults to 0 * Bias initialization value, for layers with biases. Defaults to 0
@ -255,7 +254,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
* @see IWeightInit * @see IWeightInit
*/ */
public T weightInit(IWeightInit weightInit) { public T weightInit(IWeightInit weightInit) {
this.setWeightInitFn(weightInit); this.setWeightInit(weightInit);
return (T) this; return (T) this;
} }
@ -270,7 +269,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
"Not supported!, Use weightInit(Distribution distribution) instead!"); "Not supported!, Use weightInit(Distribution distribution) instead!");
} }
this.setWeightInitFn(weightInit.getWeightInitFunction()); this.setWeightInit(weightInit.getWeightInitFunction());
return (T) this; return (T) this;
} }
@ -508,4 +507,19 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
} }
} }
/**
* Inherit setting from neural network for those settings, that are not already set or do have
* a layer(type) specific default.
* @param conf the neural net configration to inherit parameters from
*/
@Override
public void runInheritance(@NotNull NeuralNetConfiguration conf) {
super.runInheritance(conf);
if(this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater();
if(this.iUpdater == null ) this.iUpdater = conf.getUpdater();
if(this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
if(this.regularization == null ) this.regularization = conf.getRegularization();
if(this.gradientNormalization == null) this.gradientNormalization = conf.getGradientNormalization();
}
} }

View File

@ -172,6 +172,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
setNetConfiguration(conf); setNetConfiguration(conf);
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.runInheritance();
LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut());
@ -404,9 +405,10 @@ public class ConvolutionLayer extends FeedForwardLayer {
/** /**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
* Default is {@link ConvolutionMode}.Truncate.
* *
*/ */
protected ConvolutionMode convolutionMode; protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
/** /**
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions, * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions,

View File

@ -62,19 +62,18 @@ public class DenseLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.runInheritance();
org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier());
if(getWeightInit() == null) setWeightInit(new WeightInitXavier());
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
ret.setParamTable(paramTable); ret.setParamTable(paramTable);
ret.setLayerConfiguration(lconf);
return ret; return ret;
} }

View File

@ -217,14 +217,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
return this; return this;
} }
@Override
public void setWeightInitFn(IWeightInit weightInit){ public void setWeightInitFn(IWeightInit weightInit){
if(weightInit instanceof WeightInitEmbedding){ if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape(); long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]); nIn(shape[0]);
nOut(shape[1]); nOut(shape[1]);
} }
this.weightInitFn = weightInit; this.weightInit = weightInit;
} }
/** /**

View File

@ -66,28 +66,29 @@ import org.nd4j.linalg.learning.regularization.Regularization;
@Slf4j @Slf4j
public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration
protected String layerName = "noname"; protected String layerName;
@Getter @Getter
protected List<String> variables = new ArrayList<>(); protected List<String> variables = new ArrayList<>();
public void addVariable(String s) {variables.add(s);}
protected IDropout iDropout;
protected List<LayerConstraint> constraints; protected List<LayerConstraint> constraints;
protected IWeightNoise weightNoise; protected IWeightNoise weightNoise;
private IDropout iDropout;
/** /**
* The type of the layer, basically defines the base class and its properties * The type of the layer, basically defines the base class and its properties
*/ */
@Getter @Setter @NonNull @Getter @Setter @NonNull
private LayerType type = LayerType.UNKNOWN; private LayerType type = LayerType.UNKNOWN;
@Getter @Setter @Getter @Setter
private NeuralNetConfiguration netConfiguration; private NeuralNetConfiguration netConfiguration;
@Getter @Setter
private IActivation activationFn;
public LayerConfiguration(Builder builder) { public LayerConfiguration(Builder builder) {
this.layerName = builder.layerName; this.layerName = builder.layerName;
this.iDropout = builder.iDropout; this.iDropout = builder.iDropout;
} }
public void addVariable(String s) {variables.add(s);}
public String toJson() { public String toJson() {
throw new RuntimeException("toJson is not implemented for LayerConfiguration"); throw new RuntimeException("toJson is not implemented for LayerConfiguration");
} }
@ -151,6 +152,7 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
public LayerConfiguration getLayer() { public LayerConfiguration getLayer() {
return this; return this;
} }
@Override @Override
public LayerConfiguration clone() { public LayerConfiguration clone() {
try { try {
@ -218,7 +220,6 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
*/ */
public abstract void setNIn(InputType inputType, boolean override); public abstract void setNIn(InputType inputType, boolean override);
/** /**
* For the given type of input to this layer, what preprocessor (if any) is required?<br> * For the given type of input to this layer, what preprocessor (if any) is required?<br>
* Returns null if no preprocessor is required, otherwise returns an appropriate {@link * Returns null if no preprocessor is required, otherwise returns an appropriate {@link
@ -263,11 +264,11 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
"Not supported: all layers with parameters should override this method"); "Not supported: all layers with parameters should override this method");
} }
public IUpdater getIUpdater() { public IUpdater getIUpdater() {
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Not supported: all layers with parameters should override this method"); "Not supported: all layers with parameters should override this method");
} }
public void setIUpdater(IUpdater iUpdater) { public void setIUpdater(IUpdater iUpdater) {
log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName()); log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName());
} }
@ -285,15 +286,33 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
this.variables.clear(); this.variables.clear();
} }
@Getter @Setter /**
private IActivation activationFn; * Inherit setting from neural network for those settings, that are not already set or do have
* a layer(type) specific default. This implementation does not require the neural network configuration to be
* the same as the one returned from this layers {@link #getNetConfiguration()}.
*
* @param conf a neural net configration to inherit parameters from
*
*/
public void runInheritance(@NonNull NeuralNetConfiguration conf) {
if(this.activationFn == null ) this.activationFn = conf.getActivation();
if(this.iDropout == null ) this.iDropout = conf.getIdropOut();
if(this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
}
/** Runs {@link #runInheritance(NeuralNetConfiguration)} using the layers configurations embedded neural net
* configuration (the one returned from {@link #getNetConfiguration()}.
*/
public void runInheritance() {
runInheritance(getNetConfiguration());
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Getter @Getter
@Setter @Setter
public abstract static class Builder<T extends Builder<T>> { public abstract static class Builder<T extends Builder<T>> {
protected String layerName = "noname"; protected String layerName;
protected List<LayerConstraint> allParamConstraints; protected List<LayerConstraint> allParamConstraints;

View File

@ -215,7 +215,7 @@ public class LocallyConnected1D extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration global_conf = globalConfig.build(); NeuralNetConfiguration global_conf = globalConfig.build();
if (activation == null) { if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivationFn()); activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation());
} }
if (cm == null) { if (cm == null) {
cm = global_conf.getConvolutionMode(); cm = global_conf.getConvolutionMode();

View File

@ -232,7 +232,7 @@ public class LocallyConnected2D extends SameDiffLayer {
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration gconf = globalConfig.build(); NeuralNetConfiguration gconf = globalConfig.build();
if (activation == null) { if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(gconf.getActivationFn()); activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation());
} }
if (cm == null) { if (cm == null) {
cm = gconf.getConvolutionMode(); cm = gconf.getConvolutionMode();

View File

@ -117,7 +117,7 @@ public class PReLULayer extends BaseLayerConfiguration {
public Builder(){ public Builder(){
//Default to 0s, and don't inherit global default //Default to 0s, and don't inherit global default
this.weightInitFn = new WeightInitConstant(0); this.weightInit = new WeightInitConstant(0);
} }
/** /**

View File

@ -152,7 +152,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
@Override @Override
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
if (activation == null) { if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivationFn()); activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation());
} }
} }

View File

@ -196,7 +196,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
regularizationBias = bConf.getRegularizationBias(); regularizationBias = bConf.getRegularizationBias();
} }
if (updater == null) { if (updater == null) {
updater = bConf.getIUpdater(); updater = bConf.getUpdater();
} }
if (biasUpdater == null) { if (biasUpdater == null) {
biasUpdater = bConf.getBiasUpdater(); biasUpdater = bConf.getBiasUpdater();

View File

@ -156,7 +156,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
regularizationBias = b_conf.getRegularizationBias(); regularizationBias = b_conf.getRegularizationBias();
} }
if (updater == null) { if (updater == null) {
updater = b_conf.getIUpdater(); updater = b_conf.getUpdater();
} }
if (biasUpdater == null) { if (biasUpdater == null) {
biasUpdater = b_conf.getBiasUpdater(); biasUpdater = b_conf.getBiasUpdater();

View File

@ -72,6 +72,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret =
new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType); new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType);
lconf.runInheritance();
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);

View File

@ -98,7 +98,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){ protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){
for(LayerConfiguration l : layers){ for(LayerConfiguration l : layers){
if(l instanceof BaseLayerConfiguration if(l instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)l).getWeightInitFn() == null){ && ((BaseLayerConfiguration)l).getWeightInit() == null){
return true; return true;
} }
} }
@ -254,7 +254,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class); d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
} }
IWeightInit iwi = w.getWeightInitFunction(d); IWeightInit iwi = w.getWeightInitFunction(d);
baseLayerConfiguration.setWeightInitFn(iwi); baseLayerConfiguration.setWeightInit(iwi);
} catch (Throwable t){ } catch (Throwable t){
log.warn("Failed to infer weight initialization from legacy JSON format",t); log.warn("Failed to infer weight initialization from legacy JSON format",t);
} }

View File

@ -129,7 +129,7 @@ public class ComputationGraphConfigurationDeserializer
} }
if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null){ && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null){
handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
} }
@ -160,7 +160,7 @@ public class ComputationGraphConfigurationDeserializer
layerIdx++; layerIdx++;
} else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){ } else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){
if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null) { && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null) {
//Post JSON format change for subclasses, but before WeightInit was made a class //Post JSON format change for subclasses, but before WeightInit was made a class
confNode = (ObjectNode) next.get("layerConf"); confNode = (ObjectNode) next.get("layerConf");
next = confNode.get("layer"); next = confNode.get("layer");

View File

@ -141,7 +141,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
} }
if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration) layers[i]).getWeightInitFn() == null) { && ((BaseLayerConfiguration) layers[i]).getWeightInit() == null) {
handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on); handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on);
} }

View File

@ -88,14 +88,19 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
cacheMode = layerConfiguration.getNetConfiguration().getCacheMode(); cacheMode = layerConfiguration.getNetConfiguration().getCacheMode();
} }
this.dataType = dataType; this.dataType = dataType;
if (layerConfiguration.getNetConfiguration() == null) {
throw new RuntimeException("You cannot create a layer from a layer configuration, that is not part of any neural network configuration.");
}
this.net = layerConfiguration.getNetConfiguration().getNet(); this.net = layerConfiguration.getNetConfiguration().getNet();
} }
public void addTrainingListeners(TrainingListener... listeners) { public void addTrainingListeners(TrainingListener... listeners) {
if(listeners != null)
trainingListeners.addAll(List.of(listeners)); trainingListeners.addAll(List.of(listeners));
} }
public void addTrainingListeners(Collection<TrainingListener> listeners) { public void addTrainingListeners(Collection<TrainingListener> listeners) {
if(listeners != null)
trainingListeners.addAll(listeners); trainingListeners.addAll(listeners);
} }

View File

@ -77,7 +77,6 @@ public abstract class BaseLayer<LayerConfT extends BaseLayerConfiguration>
* INDArray params; * INDArray params;
*/ */
public BaseLayer(LayerConfiguration conf, DataType dataType) { public BaseLayer(LayerConfiguration conf, DataType dataType) {
super(conf, dataType); super(conf, dataType);
} }

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.layers.ocnn; package org.deeplearning4j.nn.layers.ocnn;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
@ -154,7 +153,7 @@ public class OCNNParamInitializer extends DefaultParamInitializer {
boolean initializeParameters) { boolean initializeParameters) {
org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration; org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration;
IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn(); IWeightInit weightInit = ocnnOutputLayer.getWeightInit();
if (initializeParameters) { if (initializeParameters) {
INDArray ret = weightInit.init(weightParamView.size(0), //Fan in INDArray ret = weightInit.init(weightParamView.size(0), //Fan in
weightParamView.size(1), //Fan out weightParamView.size(1), //Fan out

View File

@ -92,7 +92,7 @@ public class VariationalAutoencoder implements Layer {
protected int epochCount; protected int epochCount;
@Getter @Setter @NonNull @Getter @Setter @NonNull
private LayerConfiguration layerConfiguration; private LayerConfiguration layerConfiguration;
private @Getter @Setter Collection<TrainingListener> trainingListeners; private @Getter @Setter Collection<TrainingListener> trainingListeners = new HashSet<>();
public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) { public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) {
this.layerConfiguration = layerConfiguration; this.layerConfiguration = layerConfiguration;
@ -113,6 +113,27 @@ public class VariationalAutoencoder implements Layer {
.getNumSamples(); .getNumSamples();
} }
/**
* Replace the TrainingListeners for this model
*
* @param listeners new listeners
*/
@Override
public void addTrainingListeners(TrainingListener... listeners) {
if(listeners != null)
trainingListeners.addAll(List.of(listeners));
}
/**
*
* @param listeners
*/
@Override
public void addTrainingListeners(Collection<TrainingListener> listeners) {
if(listeners != null)
trainingListeners.addAll(listeners);
}
/** /**
* Get a reference to the network this layer is part of. * Get a reference to the network this layer is part of.
* *
@ -1214,24 +1235,6 @@ public class VariationalAutoencoder implements Layer {
//No-op for individual layers //No-op for individual layers
} }
/**
* Replace the TrainingListeners for this model
*
* @param listeners new listeners
*/
@Override
public void addTrainingListeners(TrainingListener... listeners) {
trainingListeners.addAll(List.of(listeners));
}
/**
*
* @param listeners
*/
@Override
public void addTrainingListeners(Collection<TrainingListener> listeners) {
trainingListeners.addAll(listeners);
}
@AllArgsConstructor @AllArgsConstructor
@Data @Data

View File

@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -131,7 +130,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer {
val weightsShape = new long[]{outputDepth, inputDepth, kernel[0], kernel[1], kernel[2]}; val weightsShape = new long[]{outputDepth, inputDepth, kernel[0], kernel[1], kernel[2]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView); weightView);
} else { } else {
int[] kernel = layerConf.getKernelSize(); int[] kernel = layerConf.getKernelSize();

View File

@ -180,7 +180,7 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {
val weightsShape = new long[] {outputDepth, inputDepth, kernel[0], kernel[1]}; val weightsShape = new long[] {outputDepth, inputDepth, kernel[0], kernel[1]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView); return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else { } else {
int[] kernel = layerConf.getKernelSize(); int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights( return WeightInitUtil.reshapeWeights(

View File

@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D; import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -130,7 +129,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer
//libnd4j: [kD, kH, kW, oC, iC] //libnd4j: [kD, kH, kW, oC, iC]
val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth}; val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView); return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else { } else {
int[] kernel = layerConf.getKernelSize(); int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights( return WeightInitUtil.reshapeWeights(

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params; package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,7 +60,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer {
val weightsShape = new long[] {inputDepth, outputDepth, kernel[0], kernel[1]}; val weightsShape = new long[] {inputDepth, outputDepth, kernel[0], kernel[1]};
INDArray weights = layerConf.getWeightInitFn().init( INDArray weights = layerConf.getWeightInit().init(
fanIn, fanOut, weightsShape, 'c', weightView); fanIn, fanOut, weightsShape, 'c', weightView);
return weights; return weights;

View File

@ -196,13 +196,13 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;
if (initializeParameters) { if (initializeParameters) {
if( layerConf.getWeightInitFn() == null) { if( layerConf.getWeightInit() == null) {
// set a default and set warning // set a default and set warning
layerConf.setWeightInitFn(new WeightInitXavier()); layerConf.setWeightInit(new WeightInitXavier());
log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(), log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(),
conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName()); conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName());
} }
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInitFn(), return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
weightParamView, true); weightParamView, true);
} else { } else {
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false); return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false);

View File

@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -193,7 +191,7 @@ public class DepthwiseConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {kernel[0], kernel[1], inputDepth, depthMultiplier}; val weightsShape = new long[] {kernel[0], kernel[1], inputDepth, depthMultiplier};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView); weightView);
} else { } else {
int[] kernel = layerConf.getKernelSize(); int[] kernel = layerConf.getKernelSize();

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -159,14 +157,14 @@ public class GravesBidirectionalLSTMParamInitializer extends AbstractParamInitia
val inputWShape = new long[]{nLast, 4 * nL}; val inputWShape = new long[]{nLast, 4 * nL};
val recurrentWShape = new long[]{nL, 4 * nL + 3}; val recurrentWShape = new long[]{nL, 4 * nL + 3};
params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwF)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwF));
params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwF)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwF));
params.put(BIAS_KEY_FORWARDS, bF); params.put(BIAS_KEY_FORWARDS, bF);
params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwR)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwR));
params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwR)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwR));
params.put(BIAS_KEY_BACKWARDS, bR); params.put(BIAS_KEY_BACKWARDS, bR);
} else { } else {

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -124,10 +122,10 @@ public class GravesLSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){ if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent(); rwInit = layerConf.getWeightInitFnRecurrent();
} else { } else {
rwInit = layerConf.getWeightInitFn(); rwInit = layerConf.getWeightInit();
} }
params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));

View File

@ -27,7 +27,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
@ -132,10 +131,10 @@ public class LSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){ if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent(); rwInit = layerConf.getWeightInitFnRecurrent();
} else { } else {
rwInit = layerConf.getWeightInitFn(); rwInit = layerConf.getWeightInit();
} }
params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));
biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)},

View File

@ -133,7 +133,7 @@ public class PReLUParamInitializer extends AbstractParamInitializer {
PReLULayer layerConf = (PReLULayer) conf; PReLULayer layerConf = (PReLULayer) conf;
if (initializeParameters) { if (initializeParameters) {
return layerConf.getWeightInitFn().init(layerConf.getNIn(), layerConf.getNOut(), return layerConf.getWeightInit().init(layerConf.getNIn(), layerConf.getNOut(),
weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView); weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
} else { } else {
return WeightInitUtil.reshapeWeights(weightShape, weightParamView); return WeightInitUtil.reshapeWeights(weightShape, weightParamView);

View File

@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
@ -220,7 +218,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {depthMultiplier, inputDepth, kernel[0], kernel[1]}; val weightsShape = new long[] {depthMultiplier, inputDepth, kernel[0], kernel[1]};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView); weightView);
} else { } else {
int[] kernel = layerConf.getKernelSize(); int[] kernel = layerConf.getKernelSize();
@ -249,7 +247,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ
val weightsShape = new long[] {outputDepth, depthMultiplier * inputDepth, 1, 1}; val weightsShape = new long[] {outputDepth, depthMultiplier * inputDepth, 1, 1};
return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView); weightView);
} else { } else {
return WeightInitUtil.reshapeWeights( return WeightInitUtil.reshapeWeights(

View File

@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
@ -102,14 +100,14 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {
if (initializeParams) { if (initializeParams) {
m = getSubsets(paramsView, nIn, nOut, false, hasLayerNorm(c)); m = getSubsets(paramsView, nIn, nOut, false, hasLayerNorm(c));
INDArray w = c.getWeightInitFn().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY)); INDArray w = c.getWeightInit().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY));
m.put(WEIGHT_KEY, w); m.put(WEIGHT_KEY, w);
IWeightInit rwInit; IWeightInit rwInit;
if (c.getWeightInitFnRecurrent() != null) { if (c.getWeightInitFnRecurrent() != null) {
rwInit = c.getWeightInitFnRecurrent(); rwInit = c.getWeightInitFnRecurrent();
} else { } else {
rwInit = c.getWeightInitFn(); rwInit = c.getWeightInit();
} }
INDArray rw = rwInit.init(nOut, nOut, new long[]{nOut, nOut}, 'f', m.get(RECURRENT_WEIGHT_KEY)); INDArray rw = rwInit.init(nOut, nOut, new long[]{nOut, nOut}, 'f', m.get(RECURRENT_WEIGHT_KEY));

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params; package org.deeplearning4j.nn.params;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
@ -200,7 +199,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
int[] encoderLayerSizes = layer.getEncoderLayerSizes(); int[] encoderLayerSizes = layer.getEncoderLayerSizes();
int[] decoderLayerSizes = layer.getDecoderLayerSizes(); int[] decoderLayerSizes = layer.getDecoderLayerSizes();
IWeightInit weightInit = layer.getWeightInitFn(); IWeightInit weightInit = layer.getWeightInit();
int soFar = 0; int soFar = 0;
for (int i = 0; i < encoderLayerSizes.length; i++) { for (int i = 0; i < encoderLayerSizes.length; i++) {

View File

@ -164,7 +164,7 @@ public class FineTuneConfiguration {
bl.setActivationFn(activationFn); bl.setActivationFn(activationFn);
} }
if (weightInitFn != null) { if (weightInitFn != null) {
bl.setWeightInitFn(weightInitFn); bl.setWeightInit(weightInitFn);
} }
if (biasInit != null) { if (biasInit != null) {
bl.setBiasInit(biasInit); bl.setBiasInit(biasInit);
@ -264,10 +264,10 @@ public class FineTuneConfiguration {
NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = NeuralNetConfiguration.builder(); NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = NeuralNetConfiguration.builder();
if (activationFn != null) { if (activationFn != null) {
confBuilder.activationFn(activationFn); confBuilder.activation(activationFn);
} }
if (weightInitFn != null) { if (weightInitFn != null) {
confBuilder.weightInitFn(weightInitFn); confBuilder.weightInit(weightInitFn);
} }
if (biasInit != null) { if (biasInit != null) {
confBuilder.biasInit(biasInit); confBuilder.biasInit(biasInit);

View File

@ -462,7 +462,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" + Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName()); "got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(init); layerImplF.setWeightInit(init);
layerImplF.setNIn(nIn); layerImplF.setNIn(nIn);
long numParams = layerImpl.initializer().numParams(layerConf); long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);
@ -480,7 +480,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" + Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName()); "got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme); layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut); layerImplF.setNOut(nOut);
long numParams = layerImpl.initializer().numParams(layerConf); long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);
@ -492,7 +492,7 @@ public class TransferLearning {
layerImpl = layerConf; //modify in place layerImpl = layerConf; //modify in place
if(layerImpl instanceof FeedForwardLayer) { if(layerImpl instanceof FeedForwardLayer) {
layerImplF = (FeedForwardLayer) layerImpl; layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext); layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut); layerImplF.setNIn(nOut);
numParams = layerImpl.initializer().numParams(layerConf); numParams = layerImpl.initializer().numParams(layerConf);
if (numParams > 0) { if (numParams > 0) {
@ -738,7 +738,7 @@ public class TransferLearning {
layerImpl.resetLayerDefaultConfig(); layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme); layerImplF.setWeightInit(scheme);
layerImplF.setNIn(nIn); layerImplF.setNIn(nIn);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
@ -767,7 +767,7 @@ public class TransferLearning {
LayerConfiguration layerImpl = layerConf.clone(); LayerConfiguration layerImpl = layerConf.clone();
layerImpl.resetLayerDefaultConfig(); layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme); layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut); layerImplF.setNOut(nOut);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
@ -806,7 +806,7 @@ public class TransferLearning {
continue; continue;
layerImpl = layerConf.clone(); layerImpl = layerConf.clone();
layerImplF = (FeedForwardLayer) layerImpl; layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext); layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut); layerImplF.setNIn(nOut);
nInFromNewConfig.put(fanoutVertexName, nOut); nInFromNewConfig.put(fanoutVertexName, nOut);

View File

@ -207,10 +207,11 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
*/ */
public void setStateViewArray(INDArray viewArray) { public void setStateViewArray(INDArray viewArray) {
if(this.updaterStateViewArray == null){ if(this.updaterStateViewArray == null){
if(viewArray == null) if(viewArray == null || viewArray.length()==0)
return; //No op - for example, SGD and NoOp updater - i.e., no stored state return; //No op - for example, SGD and NoOp updater - i.e., no stored state
else { else {
throw new IllegalStateException("Attempting to set updater state view array with null value"); //this.updaterStateViewArray.set
// throw new IllegalStateException("Attempting to set updater state view array with null value");
} }
} }
if (this.updaterStateViewArray.length() != viewArray.length()) if (this.updaterStateViewArray.length() != viewArray.length())
@ -296,7 +297,7 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
//PRE apply (gradient clipping, etc): done on a per-layer basis //PRE apply (gradient clipping, etc): done on a per-layer basis
for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) { for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
String layerName = entry.getKey(); String layerName = entry.getKey();
ITrainableLayer layer = layersByName.get(layerName); ITrainableLayer layer = layersByName.get(layerName); //Todo Layers may have the same name!?
preApply(layer, layerGradients.get(layerName), iteration); preApply(layer, layerGradients.get(layerName), iteration);
} }

View File

@ -29,7 +29,6 @@ import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator; import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
@ -39,7 +38,6 @@ import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
@ -85,8 +83,8 @@ class dnnTest {
.updater(Adam.builder().learningRate(0.0002).beta1(0.5).build()) .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100) .gradientNormalizationThreshold(100)
.weightInitFn(new WeightInitXavier()) .weightInit(new WeightInitXavier())
.activationFn(new ActivationSigmoid()) .activation(new ActivationSigmoid())
// .inputType(InputType.convolutional(28, 28, 1)) // .inputType(InputType.convolutional(28, 28, 1))
.layer(new DenseLayer.Builder().nIn(6).nOut(20).build()) .layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())

View File

@ -1182,7 +1182,7 @@ public class TrainModule implements UIModule {
String.valueOf(nParams)}); String.valueOf(nParams)});
if (nParams > 0) { if (nParams > 0) {
try { try {
String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn()); String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInit());
layerInfoRows.add(new String[]{ layerInfoRows.add(new String[]{
i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str}); i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str});
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType; import org.deeplearning4j.zoo.PretrainedType;

View File

@ -176,7 +176,7 @@ public class ResNet50 extends ZooModel {
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(updater) .updater(updater)
.weightInitFn(weightInit) .weightInit(weightInit)
.l1(1e-7) .l1(1e-7)
.l2(5e-5) .l2(5e-5)
.miniBatch(true) .miniBatch(true)