Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>

Branch: master
Parent: 0f21ed9ec5
Commit: 1f2e82d3ef
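The hunks below repeatedly swap the weightInitFn/getWeightInitFn() accessors for weightInit/getWeightInit() (and, in the SameDiff layers, getActivationFn() for getActivation()). For orientation only, here is a minimal, hypothetical sketch of the builder usage involved; the class and method names are taken from the hunks themselves, the import paths assume the stock DL4J package layout, and the network shape and values are invented for illustration.

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;

public class WeightInitSketch {
    public static void main(String[] args) {
        // Hypothetical configuration mirroring the pattern in the hunks below:
        // weightInit(WeightInit.XAVIER) takes the place of the older
        // weightInitFn(new WeightInitXavier()) call, and layer configurations
        // are read back through getWeightInit() instead of getWeightInitFn().
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .build();
        System.out.println(conf);
    }
}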
@@ -118,7 +118,7 @@ public class App {
 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
 //.weightInit(WeightInit.XAVIER)
-.weightInitFn(new WeightInitXavier())
+.weightInit(WeightInit.XAVIER)
 .activation(Activation.IDENTITY)
 .layersFromArray(genLayers())
 .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
@@ -74,7 +74,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
 checkSerialization(layer);

 assertEquals(act, layer.getActivationFn());
-assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn());
+assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
 assertEquals(new Dropout(dropOut), layer.getIDropout());
 assertEquals(updater, layer.getIUpdater());
 assertEquals(gradNorm, layer.getGradientNormalization());
@@ -99,8 +99,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());

 assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
 assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
@@ -117,8 +117,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
-assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit());
+assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit());

 assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
 assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
@@ -185,7 +185,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
 assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
-assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
+assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
 assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
 assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));

@@ -157,7 +157,7 @@ public class SameDiffConv extends SameDiffLayer {
 public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
 NeuralNetConfiguration clone = globalConfig.clone().build();
 if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn());
+activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
 }
 if (cm == null) {
 cm = clone.getConvolutionMode();
@@ -119,7 +119,7 @@ public class SameDiffDense extends SameDiffLayer {
 public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
 NeuralNetConfiguration clone = globalConfig.clone().build();
 if(activation == null){
-activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn());
+activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
 }
 }

@@ -141,9 +141,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
 BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
 BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
-assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
-assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
-assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
+assertEquals(bl0.getWeightInit(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
+assertEquals(bl1.getWeightInit(), new WeightInitXavier());
+assertEquals(bl1.getWeightInit(), new WeightInitXavier());

 ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
 .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
@@ -163,14 +163,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
 BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
 BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
-assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
+assertEquals(bl0.getWeightInit().getClass(), WeightInitXavier.class);
 try {
-assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
+assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInit()),
 JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1))));
 } catch (JsonProcessingException e) {
 throw new RuntimeException(e);
 }
-assertEquals(bl3.getWeightInitFn(), new WeightInitXavier());
+assertEquals(bl3.getWeightInit(), new WeightInitXavier());

 //modelNow should have the same architecture as modelExpectedArch
 assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
@@ -506,13 +506,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
 assertEquals(new Adam(1e-4), l0.getIUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);

 BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(new Adam(1e-4), l1.getIUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
-assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l1.getWeightInit());
 assertEquals(0.2, TestUtils.getL2(l1), 1e-6);

 assertEquals(BackpropType.Standard, conf.getBackpropType());
@@ -521,13 +521,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
 assertEquals(new Adam(2e-2), l0.getIUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);

 l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
 assertEquals(new Adam(2e-2), l1.getIUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
-assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l1.getWeightInit());
 assertEquals(0.2, TestUtils.getL2(l1), 1e-6);

 assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType());
@@ -37,6 +37,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;
 import org.deeplearning4j.nn.params.PretrainParamInitializer;
 import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
+import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -940,7 +941,9 @@ public class TestUpdaters extends BaseDL4JTest {

 List<UpdaterBlock> blocks;
 NeuralNetConfiguration conf =
-NeuralNetConfiguration.builder().updater(new Adam(0.5)).list()
+NeuralNetConfiguration.builder()
+.updater(new Adam(0.5))
+.weightInit(WeightInit.NORMAL)
 .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12)
 .encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build())
 .build();
@@ -72,7 +72,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals("relu", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
 assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -81,7 +81,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMCXENT);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
 assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
 assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@@ -106,7 +106,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l0.getIDropout());
@@ -118,7 +118,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMSE);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
 assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l1.getIDropout());
@@ -145,7 +145,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals("tanh", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@@ -165,7 +165,7 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -74,7 +74,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals("relu", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
 assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -83,7 +83,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMCXENT);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
 assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
 assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@@ -108,7 +108,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l0.getIDropout());
@@ -122,7 +122,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMSE);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
 assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l1.getIDropout());
@@ -151,7 +151,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals("tanh", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@@ -171,7 +171,7 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -75,7 +75,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertEquals("relu", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
 assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -84,7 +84,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMCXENT);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
 assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
 assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@@ -109,7 +109,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l0.getIDropout());
@@ -123,7 +123,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMSE);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l1.getIDropout());
@@ -152,7 +152,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertEquals("tanh", l0.getActivationFn().toString());
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
@@ -172,7 +172,7 @@ public class RegressionTest071 extends BaseDL4JTest {
 assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
 assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

@@ -74,7 +74,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationReLU);
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertTrue(l0.getIUpdater() instanceof Nesterovs);
 Nesterovs n = (Nesterovs) l0.getIUpdater();
 assertEquals(0.9, n.getMomentum(), 1e-6);
@@ -87,7 +87,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMCXENT);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertTrue(l1.getIUpdater() instanceof Nesterovs);
 assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
 assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
@@ -113,7 +113,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationLReLU);
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
 assertTrue(l0.getIUpdater() instanceof RmsProp);
 RmsProp r = (RmsProp) l0.getIUpdater();
 assertEquals(0.96, r.getRmsDecay(), 1e-6);
@@ -130,7 +130,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l1.getLossFn() instanceof LossMSE);
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
-assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn());
+assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
 assertTrue(l1.getIUpdater() instanceof RmsProp);
 r = (RmsProp) l1.getIUpdater();
 assertEquals(0.96, r.getRmsDecay(), 1e-6);
@@ -162,7 +162,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l0.getActivationFn() instanceof ActivationTanH);
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
-assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertTrue(l0.getIUpdater() instanceof RmsProp);
 RmsProp r = (RmsProp) l0.getIUpdater();
 assertEquals(0.96, r.getRmsDecay(), 1e-6);
@@ -185,7 +185,7 @@ public class RegressionTest080 extends BaseDL4JTest {
 assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood);
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
-assertEquals(new WeightInitRelu(), l2.getWeightInitFn());
+assertEquals(new WeightInitRelu(), l2.getWeightInit());
 assertTrue(l2.getIUpdater() instanceof RmsProp);
 r = (RmsProp) l2.getIUpdater();
 assertEquals(0.96, r.getRmsDecay(), 1e-6);
@@ -89,21 +89,21 @@ public class RegressionTest100a extends BaseDL4JTest {
 GravesLSTM l0 = (GravesLSTM) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l0.getActivationFn());
 assertEquals(200, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new RmsProp(0.1), l0.getIUpdater());

 GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l1.getActivationFn());
 assertEquals(200, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
 assertEquals(new RmsProp(0.1), l1.getIUpdater());

 RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
 assertEquals(new ActivationSoftmax(), l2.getActivationFn());
 assertEquals(77, l2.getNOut());
-assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l2.getWeightInit());
 assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new RmsProp(0.1), l0.getIUpdater());

@@ -139,7 +139,7 @@ public class RegressionTest100a extends BaseDL4JTest {
 assertEquals(32, l0.getNOut());
 assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
 assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new Adam(0.05), l0.getIUpdater());

@@ -175,7 +175,7 @@ public class RegressionTest100a extends BaseDL4JTest {
 assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
 assertEquals(new ActivationIdentity(), cl.getActivationFn());
 assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
-assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
+assertEquals(new WeightInitXavier(), cl.getWeightInit());
 assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
 assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

@@ -124,21 +124,21 @@ public class RegressionTest100b3 extends BaseDL4JTest {
 LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l0.getActivationFn());
 assertEquals(200, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());

 LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l1.getActivationFn());
 assertEquals(200, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
 assertEquals(new Adam(0.005), l1.getIUpdater());

 RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
 assertEquals(new ActivationSoftmax(), l2.getActivationFn());
 assertEquals(77, l2.getNOut());
-assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l2.getWeightInit());
 assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());

@@ -174,7 +174,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
 assertEquals(32, l0.getNOut());
 assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
 assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
 assertEquals(new Adam(1e-3), l0.getIUpdater());

@@ -210,7 +210,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
 assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
 assertEquals(new ActivationIdentity(), cl.getActivationFn());
 assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
-assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
+assertEquals(new WeightInitXavier(), cl.getWeightInit());
 assertArrayEquals(new int[]{1,1}, cl.getKernelSize());
 assertArrayEquals(new int[]{1,1}, cl.getKernelSize());

@@ -142,21 +142,21 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l0.getActivationFn());
 assertEquals(200, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());

 LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l1.getActivationFn());
 assertEquals(200, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
 assertEquals(new Adam(0.005), l1.getIUpdater());

 RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
 assertEquals(new ActivationSoftmax(), l2.getActivationFn());
 assertEquals(77, l2.getNOut());
-assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l2.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
 assertEquals(new Adam(0.005), l2.getIUpdater());

@@ -192,7 +192,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 assertEquals(32, l0.getNOut());
 assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
 assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(1e-3), l0.getIUpdater());

@@ -229,7 +229,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
 assertEquals(new ActivationIdentity(), cl.getActivationFn());
 assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
-assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
+assertEquals(new WeightInitXavier(), cl.getWeightInit());
 assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());

 INDArray outExp;
@@ -260,7 +260,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l0.getActivationFn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@@ -271,7 +271,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l1.getActivationFn());
 assertEquals(8, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
 assertEquals(new Adam(0.005), l1.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@@ -297,7 +297,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l5.getActivationFn());
 assertEquals(16, l5.getNOut());
-assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l5.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
 assertEquals(new Adam(0.005), l5.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@@ -318,7 +318,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {

 ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
 assertEquals(4, l8.getNOut());
-assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l8.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
 assertEquals(new Adam(0.005), l8.getIUpdater());
 assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
@@ -327,7 +327,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 assertArrayEquals(new int[]{0, 0}, l8.getPadding());

 CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
-assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l9.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
 assertEquals(new Adam(0.005), l9.getIUpdater());
 assertEquals(new LossMAE(), l9.getLossFn());
@@ -124,21 +124,21 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l0.getActivationFn());
 assertEquals(200, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());

 LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationTanH(), l1.getActivationFn());
 assertEquals(200, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
 assertEquals(new Adam(0.005), l1.getIUpdater());

 RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
 assertEquals(new ActivationSoftmax(), l2.getActivationFn());
 assertEquals(77, l2.getNOut());
-assertEquals(new WeightInitXavier(), l2.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l2.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
 assertEquals(new Adam(0.005), l2.getIUpdater());

@@ -174,7 +174,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 assertEquals(32, l0.getNOut());
 assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
 assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(1e-3), l0.getIUpdater());

@@ -210,7 +210,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
 assertEquals(new ActivationIdentity(), cl.getActivationFn());
 assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
-assertEquals(new WeightInitXavier(), cl.getWeightInitFn());
+assertEquals(new WeightInitXavier(), cl.getWeightInit());
 assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());

 INDArray outExp;
@@ -240,7 +240,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l0.getActivationFn());
 assertEquals(4, l0.getNOut());
-assertEquals(new WeightInitXavier(), l0.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l0.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
 assertEquals(new Adam(0.005), l0.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
@@ -251,7 +251,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l1.getActivationFn());
 assertEquals(8, l1.getNOut());
-assertEquals(new WeightInitXavier(), l1.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l1.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
 assertEquals(new Adam(0.005), l1.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
@@ -277,7 +277,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
 assertEquals(new ActivationReLU(), l5.getActivationFn());
 assertEquals(16, l5.getNOut());
-assertEquals(new WeightInitXavier(), l5.getWeightInitFn());
+assertEquals(new WeightInitXavier(), l5.getWeightInit());
 assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
 assertEquals(new Adam(0.005), l5.getIUpdater());
 assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
@@ -298,7 +298,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@ -298,7 +298,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
|
|
||||||
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
|
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
|
||||||
assertEquals(4, l8.getNOut());
|
assertEquals(4, l8.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l8.getWeightInitFn());
|
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
||||||
assertEquals(new Adam(0.005), l8.getIUpdater());
|
assertEquals(new Adam(0.005), l8.getIUpdater());
|
||||||
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
||||||
|
@ -307,7 +307,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
|
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
|
||||||
|
|
||||||
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
||||||
assertEquals(new WeightInitXavier(), l9.getWeightInitFn());
|
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
||||||
assertEquals(new Adam(0.005), l9.getIUpdater());
|
assertEquals(new Adam(0.005), l9.getIUpdater());
|
||||||
assertEquals(new LossMAE(), l9.getLossFn());
|
assertEquals(new LossMAE(), l9.getLossFn());
|
||||||
|
|
|
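All of the assertions above change for the same reason: the accessor on BaseLayerConfiguration was renamed from getWeightInitFn() to getWeightInit(). A minimal sketch of how calling code migrates, assuming a layer configuration obtained the same way as in the tests (the helper class and method are illustrative only, not part of this commit):

import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;

// Before: IWeightInit wi = layer.getWeightInitFn();  After: the "Fn" suffix is gone.
public class WeightInitAccessSketch {
    static IWeightInit resolve(BaseLayerConfiguration layer) {
        IWeightInit wi = layer.getWeightInit();   // renamed getter
        if (wi == null) {
            wi = new WeightInitXavier();          // same fallback several layers in this commit apply
            layer.setWeightInit(wi);              // renamed setter
        }
        return wi;
    }
}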
@@ -167,7 +167,7 @@ public class KerasInitilizationTest extends BaseDL4JTest {
 layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion);

 DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
-assertEquals(dl4jInitializer, layer.getWeightInitFn());
+assertEquals(dl4jInitializer, layer.getWeightInit());

 }
 }

@@ -79,7 +79,7 @@ public class KerasPReLUTest extends BaseDL4JTest {

 PReLULayer layer = kerasPReLU.getPReLULayer();
 assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4});
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());

 assertEquals(layerName, layer.getLayerName());
 }

@@ -100,7 +100,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
 Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -114,7 +114,7 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest {
 ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -122,7 +122,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
 Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -123,7 +123,7 @@ public class KerasConvolution2DTest extends BaseDL4JTest {
 ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -119,7 +119,7 @@ public class KerasConvolution3DTest extends BaseDL4JTest {
 ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -123,7 +123,7 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest {
 Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -128,7 +128,7 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest {
 DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);

@@ -130,7 +130,7 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest {
 SeparableConvolution2D layer = new KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier());

@@ -89,7 +89,7 @@ public class KerasDenseTest extends BaseDL4JTest {
 DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -38,7 +38,6 @@ import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;

-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;

@@ -131,7 +130,7 @@ public class KerasLSTMTest extends BaseDL4JTest {
 }
 assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());

@@ -101,7 +101,7 @@ public class KerasSimpleRnnTest extends BaseDL4JTest {
 (SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying();
 assertEquals(ACTIVATION, layer.getActivationFn().toString());
 assertEquals(LAYER_NAME, layer.getLayerName());
-assertEquals(INIT_DL4J, layer.getWeightInitFn());
+assertEquals(INIT_DL4J, layer.getWeightInit());
 assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
 assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
 assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout());
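The Keras import tests above all follow one pattern: build a Keras layer config map, import it, and assert that the Keras initializer maps onto the expected DL4J IWeightInit through the renamed getWeightInit(). A compressed sketch of that shared assertion; the config map, expected initializer and import path are placeholders standing in for the per-test constants (INIT_DL4J, layerConfig) above:

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.Map;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.weights.IWeightInit;

// Sketch of the shared test pattern, not a test from this commit.
class KerasWeightInitAssertionSketch {
    void checkMapping(Map<String, Object> layerConfig, IWeightInit expectedInit) throws Exception {
        DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer();
        assertEquals(expectedInit, layer.getWeightInit()); // was getWeightInitFn() before this commit
    }
}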
@@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 public interface INeuralNetworkConfiguration extends Serializable, Cloneable {

 INeuralNetworkConfiguration clone();

 void init();

 /**

@@ -35,28 +36,4 @@ public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
 * @return
 */
 IModel getNet();
 }
-/**
-/**
- * Provides a flat list of all embedded layer configurations, this
- * can only be called after the layer is initialized or {@link #getLayerConfigurations()} is
- * called.
- *
- * @return unstacked layer configurations
-
- List<ILayerConfiguration> getLayerConfigurations();
-
-
- /**
- * This uncollables any stacked layer configurations within building blocks like
- * @link BuildingBlockLayer}
-
- void calculateInnerLayerConfigurations();
-
- /**
- * An implementation should provide a method to validate the network
- * @return true if no errors found; false otherwise
-
- boolean isValid();
-}
-**/
@@ -259,7 +259,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
 */
 private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) {
 if (layer instanceof BaseLayerConfiguration
-&& ((BaseLayerConfiguration) layer).getWeightInitFn() == null) {
+&& ((BaseLayerConfiguration) layer).getWeightInit() == null) {
 String layerName = layer.getLayerName();

 try {

@@ -291,7 +291,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {

 if (weightInit != null) {
 final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
-((BaseLayerConfiguration) layer).setWeightInitFn(wi);
+((BaseLayerConfiguration) layer).setWeightInit(wi);
 }

 } catch (IOException e) {

[File diff suppressed because it is too large]
@@ -35,15 +35,11 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
-import lombok.Data;
-import lombok.EqualsAndHashCode;
+import lombok.*;
-import lombok.Getter;
-import lombok.NonNull;
-import lombok.Setter;
 import lombok.experimental.SuperBuilder;
 import lombok.extern.jackson.Jacksonized;
 import lombok.extern.slf4j.Slf4j;
-import lombok.val;
 import net.brutex.ai.dnn.api.IModel;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;

@@ -67,9 +63,9 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
 import org.deeplearning4j.nn.conf.serde.JsonMappers;
+import org.deeplearning4j.nn.conf.stepfunctions.DefaultStepFunction;
 import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
 import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
-import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.util.OutputLayerUtil;
@@ -319,16 +315,14 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 private boolean validateTbpttConfig = true;
 /**
 * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or
-* {@link org.nd4j.linalg.learning.config.Nesterovs}<br> Note: values set by this method will be
-* applied to all applicable layers in the network, unless a different value is explicitly set on
-* a given layer. In other words: values set via this method are used as the default value, and
-* can be overridden on a per-layer basis.
+* {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
+* Note: values set by this method will be applied to all applicable layers in the network, unless
+* a different value is explicitly set on a given layer. In other words: values set via this
+* method are used as the default value, and can be overridden on a per-layer basis.
 *
 * @param updater Updater to use
 */
-@Getter
-@Setter
-private IUpdater updater;
+@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
 /**
 * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
 * etc. See {@link GradientNormalization} for details<br> Note: values set by this method will be

@@ -357,19 +351,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 @Setter
 private double gradientNormalizationThreshold;

-/**
- * Activation function / neuron non-linearity<br> Note: values set by this method will be applied
- * to all applicable layers in the network, unless a different value is explicitly set on a given
- * layer. In other words: values set via this method are used as the default value, and can be
- * overridden on a per-layer basis.
- */
-@Getter
-@Setter
-private IActivation activation;
-//whether to constrain the gradient to unit norm or not
-@Getter
-@Setter
-private StepFunction stepFunction;
+// whether to constrain the gradient to unit norm or not
+@Getter @Setter @Builder.Default private StepFunction stepFunction = new DefaultStepFunction();
 @Getter
 @Setter
 @lombok.Builder.Default

@@ -400,13 +384,10 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 @Getter
 @lombok.Builder.Default
 private List<Regularization> regularizationBias = new ArrayList<>();
-@Getter
-@Setter
-@lombok.Builder.Default
-private IUpdater iUpdater = new Sgd();
 /**
 * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
-* set by {@link #setIUpdater(IUpdater)}<br>
+* set by {@link #setUpdater(IUpdater)}<br>
 * Note: values set by this method will be applied to all applicable layers in the network, unless a different
 * value is explicitly set on a given layer. In other words: values set via this method are used as the default
 * value, and can be overridden on a per-layer basis.

@@ -420,7 +401,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 @Getter
 @Setter
 @lombok.Builder.Default
-private IActivation activationFn = new ActivationSigmoid();
+private IActivation activation = new ActivationSigmoid();

 /**
 * Sets the convolution mode for convolutional layers, which impacts padding and output sizes.
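The duplicate iUpdater field is gone: the single updater field now carries a Lombok @Builder.Default of Sgd, the step function defaults to DefaultStepFunction, and the network-level activation default lives on the renamed activation field. A hedged usage sketch, assuming the @SuperBuilder-generated, field-named builder methods implied by these declarations (layer wiring omitted; nothing here is taken verbatim from the commit):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.learning.config.Adam;

// Network-wide defaults; layers that set nothing of their own inherit these values.
public class NetworkDefaultsSketch {
    static NeuralNetConfiguration defaults() {
        return NeuralNetConfiguration.builder()
                .updater(new Adam(1e-3))          // replaces the removed iUpdater field; Sgd if never set
                .activation(new ActivationReLU()) // replaces activationFn; ActivationSigmoid if never set
                .build();
    }
}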
@@ -698,7 +679,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
 ObjectMapper mapper,
 JsonNode confs, int layerCount) {
-if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) {
+if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) {
 try {
 JsonNode jsonNode = mapper.readTree(json);
 if (confs == null) {

@@ -729,7 +710,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 if (weightInit != null) {
 final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
 .getWeightInitFunction(dist);
-((BaseLayerConfiguration) l).setWeightInitFn(wi);
+((BaseLayerConfiguration) l).setWeightInit(wi);
 }
 }

@@ -851,8 +832,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 * that do not have an individual setting (nor a default)
 */
 for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
-if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivationFn());
-if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getIUpdater() );
+if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivation());
+if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getUpdater() );
 if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() );
 if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise());
@@ -1108,13 +1089,19 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 */
 public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
 List<LayerConfiguration> ret = new ArrayList<>(); //create the final return list
-for( Object obj : conf.getInnerConfigurations().stream().skip(1) //don't include self
-.collect(Collectors.toList())) {
+//When properly initialized, _this_ configuration is set first in the list, however we
+//can find cases where this is not true, thus the first configuration is another net or layer configuration
+//and should not be skipped. In essence, skip first configuration if that is "this".
+int iSkip = 0;
+if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
+conf.getInnerConfigurations().stream().skip(iSkip)
+.forEach(obj -> {
 //if Layer Config, include in list and inherit parameters from this conf
 //else if neural net configuration, call self recursively to resolve layer configurations
-if (obj instanceof LayerConfiguration)
+if (obj instanceof LayerConfiguration) {
+((LayerConfiguration) obj).setNetConfiguration(conf);
 ret.add((LayerConfiguration) obj);
-else if (obj instanceof NeuralNetConfiguration)
+} else if (obj instanceof NeuralNetConfiguration)
 ret.addAll(getFlattenedLayerConfigurations(
 (NeuralNetConfiguration) obj));
 else {

@@ -1122,15 +1109,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 "The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
 obj.getClass().getSimpleName());
 }
-}
+});
-/**
-LayerConfiguration lc = ((LayerConfiguration) lc).getType().getClazz().cast(obj);
-switch(lc.getType()) {
-case FC: { //fully connected layer
-((FeedForwardLayer) lc).setWeightInitFn(this.getWeightInitFn());
-}
-if(lc instanceof FeedForwardLayer && ((FeedForwardLayer) lc).getWeightInitFn() == null) {
-**/
 return ret;
 }

@@ -1143,17 +1122,6 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
 return getFlattenedLayerConfigurations(this);
 }

-/**
- * Get the configuration of the first layer
- * @return layer configuration
- */
-/**
-public LayerConfiguration getFirstLayer() {
-return getFlattenedLayerConfigurations().get(0);
-}
-**/

 /**
 * Add a new layer to the first position
 * @param layer configuration
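The behavioural fix above is easy to miss inside the lambda conversion: the old loop always skipped the first inner configuration on the assumption that it was this configuration itself; the new code skips it only when it really is. A self-contained illustration of that guard in plain Java (not DL4J code; self plays the role of the enclosing configuration, items the role of getInnerConfigurations()):

import java.util.List;

final class SkipSelfSketch {
    static List<Object> childrenOf(Object self, List<Object> items) {
        // old behaviour was items.stream().skip(1), which drops a real child when self is absent
        int skip = (!items.isEmpty() && items.get(0).equals(self)) ? 1 : 0;
        return items.subList(skip, items.size());
    }
}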
@@ -23,6 +23,7 @@ package org.deeplearning4j.nn.conf.layers;
 import lombok.*;
 import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
 import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.Updater;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;

@@ -30,6 +31,7 @@ import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitDistribution;
 import org.deeplearning4j.util.NetworkUtils;
+import org.jetbrains.annotations.NotNull;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.buffer.DataType;

@@ -52,7 +54,7 @@ import java.util.List;
 public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable {

 @NonNull
-protected IWeightInit weightInitFn;
+protected IWeightInit weightInit;
 protected double biasInit = 0.0;
 protected double gainInit = 0.0;
 protected List<Regularization> regularization;

@@ -68,7 +70,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 public BaseLayerConfiguration(Builder builder) {
 super(builder);
 this.layerName = builder.layerName;
-this.weightInitFn = builder.weightInitFn;
+this.weightInit = builder.weightInit;
 this.biasInit = builder.biasInit;
 this.gainInit = builder.gainInit;
 this.regularization = builder.regularization;

@@ -89,7 +91,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 public void resetLayerDefaultConfig() {
 //clear the learning related params for all layers in the origConf and set to defaults
 this.setIUpdater(null);
-this.setWeightInitFn(null);
+this.setWeightInit(null);
 this.setBiasInit(Double.NaN);
 this.setGainInit(Double.NaN);
 this.regularization = null;

@@ -103,9 +105,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 @Override
 public BaseLayerConfiguration clone() {
 BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
-if (clone.iDropout != null) {
-clone.iDropout = clone.iDropout.clone();
-}
 if(regularization != null){
 //Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure
 clone.regularization = new ArrayList<>(regularization.size());

@@ -170,7 +169,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 *
 * @see IWeightInit
 */
-protected IWeightInit weightInitFn = null;
+protected IWeightInit weightInit = null;

 /**
 * Bias initialization value, for layers with biases. Defaults to 0

@@ -255,7 +254,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 * @see IWeightInit
 */
 public T weightInit(IWeightInit weightInit) {
-this.setWeightInitFn(weightInit);
+this.setWeightInit(weightInit);
 return (T) this;
 }

@@ -270,7 +269,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 "Not supported!, Use weightInit(Distribution distribution) instead!");
 }

-this.setWeightInitFn(weightInit.getWeightInitFunction());
+this.setWeightInit(weightInit.getWeightInitFunction());
 return (T) this;
 }
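The public weightInit(...) builder overloads are unchanged; only the backing field and setter lost the Fn suffix. A hedged sketch of configuring a layer through the builder, assuming the usual DenseLayer.Builder with nIn/nOut setters (the concrete layer type and sizes are illustrative, not taken from this commit):

import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;

// Both overloads shown above now funnel into setWeightInit(...).
public class BuilderWeightInitSketch {
    static DenseLayer explicitInstance() {
        return new DenseLayer.Builder()
                .nIn(128).nOut(64)
                .weightInit(new WeightInitXavier()) // IWeightInit overload
                .build();
    }

    static DenseLayer fromEnum() {
        return new DenseLayer.Builder()
                .nIn(128).nOut(64)
                .weightInit(WeightInit.XAVIER)      // enum overload, converted via getWeightInitFunction()
                .build();
    }
}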
@@ -508,4 +507,19 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen
 }
 }

+/**
+ * Inherit setting from neural network for those settings, that are not already set or do have
+ * a layer(type) specific default.
+ * @param conf the neural net configration to inherit parameters from
+ */
+@Override
+public void runInheritance(@NotNull NeuralNetConfiguration conf) {
+super.runInheritance(conf);
+if(this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater();
+if(this.iUpdater == null ) this.iUpdater = conf.getUpdater();
+if(this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
+if(this.regularization == null ) this.regularization = conf.getRegularization();
+if(this.gradientNormalization == null) this.gradientNormalization = conf.getGradientNormalization();
+}
+
 }
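Together with the base-class overload added later in this commit, runInheritance() gives every layer a late-bound fallback: any setting still null on the layer is copied from the owning NeuralNetConfiguration just before the layer is instantiated. A hedged sketch of the resulting resolution order (the helper class is illustrative and not part of the commit):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.nd4j.linalg.learning.config.IUpdater;

// Resolution order implied by runInheritance():
// layer-specific value -> network default -> hard-coded default (e.g. Sgd for the updater).
public class InheritanceSketch {
    static IUpdater effectiveUpdater(BaseLayerConfiguration layer, NeuralNetConfiguration net) {
        layer.runInheritance(net);  // copies net.getUpdater() onto the layer only if it is unset
        return layer.getIUpdater();
    }
}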
@@ -172,6 +172,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
 int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
 setNetConfiguration(conf);
 LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
+lconf.runInheritance();

 LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut());

@@ -404,9 +405,10 @@ public class ConvolutionLayer extends FeedForwardLayer {

 /**
 * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details
+ * Default is {@link ConvolutionMode}.Truncate.
 *
 */
-protected ConvolutionMode convolutionMode;
+protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;

 /**
 * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions,
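With the new default, a ConvolutionLayer that never sets a mode behaves as Truncate rather than carrying a null. A small arithmetic sketch of what the two modes do to the output size, assuming the usual semantics (Truncate crops any remainder, Same pads so the whole input is covered); this is an illustration of the standard formulas, not code from the commit:

// Truncate: out = floor((in - kernel) / stride) + 1
// Same:     out = ceil(in / stride)
public class ConvOutputSizeSketch {
    public static void main(String[] args) {
        int in = 28, kernel = 3, stride = 2;
        int truncate = (in - kernel) / stride + 1; // (28 - 3) / 2 + 1 = 13
        int same = (in + stride - 1) / stride;     // ceil(28 / 2) = 14
        System.out.printf("Truncate: %d, Same: %d%n", truncate, same);
    }
}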
@@ -62,19 +62,18 @@ public class DenseLayer extends FeedForwardLayer {
 int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {

 LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut());

 LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
+lconf.runInheritance();

 org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
 new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
-if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier());
+if(getWeightInit() == null) setWeightInit(new WeightInitXavier());
 ret.addTrainingListeners(trainingListeners);
 ret.setIndex(layerIndex);
 ret.setParamsViewArray(layerParamsView);
 Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
 ret.setParamTable(paramTable);
-ret.setLayerConfiguration(lconf);

 return ret;
 }
@@ -217,14 +217,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
 return this;
 }

-@Override
 public void setWeightInitFn(IWeightInit weightInit){
 if(weightInit instanceof WeightInitEmbedding){
 long[] shape = ((WeightInitEmbedding) weightInit).shape();
 nIn(shape[0]);
 nOut(shape[1]);
 }
-this.weightInitFn = weightInit;
+this.weightInit = weightInit;
 }

 /**
@@ -66,28 +66,29 @@ import org.nd4j.linalg.learning.regularization.Regularization;
 @Slf4j
 public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration

-protected String layerName = "noname";
+protected String layerName;
 @Getter
 protected List<String> variables = new ArrayList<>();
-public void addVariable(String s) {variables.add(s);}

-protected IDropout iDropout;
 protected List<LayerConstraint> constraints;
 protected IWeightNoise weightNoise;
+private IDropout iDropout;
 /**
 * The type of the layer, basically defines the base class and its properties
 */
 @Getter @Setter @NonNull
 private LayerType type = LayerType.UNKNOWN;

 @Getter @Setter
 private NeuralNetConfiguration netConfiguration;
+@Getter @Setter
+private IActivation activationFn;

 public LayerConfiguration(Builder builder) {
 this.layerName = builder.layerName;
 this.iDropout = builder.iDropout;
 }

+public void addVariable(String s) {variables.add(s);}

 public String toJson() {
 throw new RuntimeException("toJson is not implemented for LayerConfiguration");
 }
@@ -151,6 +152,7 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
 public LayerConfiguration getLayer() {
 return this;
 }

 @Override
 public LayerConfiguration clone() {
 try {

@@ -218,7 +220,6 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
 */
 public abstract void setNIn(InputType inputType, boolean override);

-
 /**
 * For the given type of input to this layer, what preprocessor (if any) is required?<br>
 * Returns null if no preprocessor is required, otherwise returns an appropriate {@link

@@ -263,11 +264,11 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
 "Not supported: all layers with parameters should override this method");
 }

 public IUpdater getIUpdater() {
 throw new UnsupportedOperationException(
 "Not supported: all layers with parameters should override this method");
 }

 public void setIUpdater(IUpdater iUpdater) {
 log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName());
 }

@@ -285,15 +286,33 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali
 this.variables.clear();
 }

-@Getter @Setter
-private IActivation activationFn;
+/**
+ * Inherit setting from neural network for those settings, that are not already set or do have
+ * a layer(type) specific default. This implementation does not require the neural network configuration to be
+ * the same as the one returned from this layers {@link #getNetConfiguration()}.
+ *
+ * @param conf a neural net configration to inherit parameters from
+ *
+ */
+public void runInheritance(@NonNull NeuralNetConfiguration conf) {
+if(this.activationFn == null ) this.activationFn = conf.getActivation();
+if(this.iDropout == null ) this.iDropout = conf.getIdropOut();
+if(this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
+}
+
+/** Runs {@link #runInheritance(NeuralNetConfiguration)} using the layers configurations embedded neural net
+ * configuration (the one returned from {@link #getNetConfiguration()}.
+ */
+public void runInheritance() {
+runInheritance(getNetConfiguration());
+}
+
 @SuppressWarnings("unchecked")
 @Getter
 @Setter
 public abstract static class Builder<T extends Builder<T>> {

-protected String layerName = "noname";
+protected String layerName;

 protected List<LayerConstraint> allParamConstraints;
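The no-argument overload simply delegates to the configuration already attached to the layer, which is why the instantiate() implementations earlier in this commit can call lconf.runInheritance() directly after looking the layer up. A hedged sketch of that call-site pattern (the wrapper class is illustrative only):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;

public class InstantiatePatternSketch {
    static LayerConfiguration resolve(NeuralNetConfiguration conf, int layerIndex) {
        LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
        lconf.runInheritance(); // uses the net configuration attached by getFlattenedLayerConfigurations(conf)
        return lconf;
    }
}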
@@ -215,7 +215,7 @@ public class LocallyConnected1D extends SameDiffLayer {
 public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
 NeuralNetConfiguration global_conf = globalConfig.build();
 if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivationFn());
+activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation());
 }
 if (cm == null) {
 cm = global_conf.getConvolutionMode();

@@ -232,7 +232,7 @@ public class LocallyConnected2D extends SameDiffLayer {
 public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
 NeuralNetConfiguration gconf = globalConfig.build();
 if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(gconf.getActivationFn());
+activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation());
 }
 if (cm == null) {
 cm = gconf.getConvolutionMode();

@@ -117,7 +117,7 @@ public class PReLULayer extends BaseLayerConfiguration {

 public Builder(){
 //Default to 0s, and don't inherit global default
-this.weightInitFn = new WeightInitConstant(0);
+this.weightInit = new WeightInitConstant(0);
 }

 /**

@@ -152,7 +152,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
 @Override
 public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
 if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivationFn());
+activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation());
 }
 }
@@ -196,7 +196,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
 regularizationBias = bConf.getRegularizationBias();
 }
 if (updater == null) {
-updater = bConf.getIUpdater();
+updater = bConf.getUpdater();
 }
 if (biasUpdater == null) {
 biasUpdater = bConf.getBiasUpdater();

@@ -156,7 +156,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
 regularizationBias = b_conf.getRegularizationBias();
 }
 if (updater == null) {
-updater = b_conf.getIUpdater();
+updater = b_conf.getUpdater();
 }
 if (biasUpdater == null) {
 biasUpdater = b_conf.getBiasUpdater();
@@ -72,6 +72,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
 LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
 org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret =
 new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType);
+lconf.runInheritance();

 ret.addTrainingListeners(trainingListeners);
 ret.setIndex(layerIndex);
@@ -98,7 +98,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){
for(LayerConfiguration l : layers){
if(l instanceof BaseLayerConfiguration
-&& ((BaseLayerConfiguration)l).getWeightInitFn() == null){
+&& ((BaseLayerConfiguration)l).getWeightInit() == null){
return true;
}
}

@@ -254,7 +254,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
}
IWeightInit iwi = w.getWeightInitFunction(d);
-baseLayerConfiguration.setWeightInitFn(iwi);
+baseLayerConfiguration.setWeightInit(iwi);
} catch (Throwable t){
log.warn("Failed to infer weight initialization from legacy JSON format",t);
}

@@ -129,7 +129,7 @@ public class ComputationGraphConfigurationDeserializer
}

if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
-&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null){
+&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null){
handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
}

@@ -160,7 +160,7 @@ public class ComputationGraphConfigurationDeserializer
layerIdx++;
} else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){
if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration
-&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null) {
+&& ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null) {
//Post JSON format change for subclasses, but before WeightInit was made a class
confNode = (ObjectNode) next.get("layerConf");
next = confNode.get("layer");

@@ -141,7 +141,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
}

if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration
-&& ((BaseLayerConfiguration) layers[i]).getWeightInitFn() == null) {
+&& ((BaseLayerConfiguration) layers[i]).getWeightInit() == null) {
handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on);
}

@@ -88,14 +88,19 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
cacheMode = layerConfiguration.getNetConfiguration().getCacheMode();
}
this.dataType = dataType;
+if (layerConfiguration.getNetConfiguration() == null) {
+throw new RuntimeException("You cannot create a layer from a layer configuration, that is not part of any neural network configuration.");
+}
this.net = layerConfiguration.getNetConfiguration().getNet();
}

public void addTrainingListeners(TrainingListener... listeners) {
+if(listeners != null)
trainingListeners.addAll(List.of(listeners));
}

public void addTrainingListeners(Collection<TrainingListener> listeners) {
+if(listeners != null)
trainingListeners.addAll(listeners);
}

@@ -77,7 +77,6 @@ public abstract class BaseLayer<LayerConfT extends BaseLayerConfiguration>
* INDArray params;
*/
public BaseLayer(LayerConfiguration conf, DataType dataType) {

super(conf, dataType);
}

@@ -21,7 +21,6 @@
package org.deeplearning4j.nn.layers.ocnn;

import lombok.val;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;

@@ -154,7 +153,7 @@ public class OCNNParamInitializer extends DefaultParamInitializer {
boolean initializeParameters) {

org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration;
-IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn();
+IWeightInit weightInit = ocnnOutputLayer.getWeightInit();
if (initializeParameters) {
INDArray ret = weightInit.init(weightParamView.size(0), //Fan in
weightParamView.size(1), //Fan out

@@ -92,7 +92,7 @@ public class VariationalAutoencoder implements Layer {
protected int epochCount;
@Getter @Setter @NonNull
private LayerConfiguration layerConfiguration;
-private @Getter @Setter Collection<TrainingListener> trainingListeners;
+private @Getter @Setter Collection<TrainingListener> trainingListeners = new HashSet<>();

public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) {
this.layerConfiguration = layerConfiguration;

@@ -113,6 +113,27 @@ public class VariationalAutoencoder implements Layer {
.getNumSamples();
}

+/**
+* Replace the TrainingListeners for this model
+*
+* @param listeners new listeners
+*/
+@Override
+public void addTrainingListeners(TrainingListener... listeners) {
+if(listeners != null)
+trainingListeners.addAll(List.of(listeners));
+}
+
+/**
+*
+* @param listeners
+*/
+@Override
+public void addTrainingListeners(Collection<TrainingListener> listeners) {
+if(listeners != null)
+trainingListeners.addAll(listeners);
+}
+
/**
* Get a reference to the network this layer is part of.
*

@@ -1214,24 +1235,6 @@ public class VariationalAutoencoder implements Layer {
//No-op for individual layers
}

-/**
-* Replace the TrainingListeners for this model
-*
-* @param listeners new listeners
-*/
-@Override
-public void addTrainingListeners(TrainingListener... listeners) {
-trainingListeners.addAll(List.of(listeners));
-}
-
-/**
-*
-* @param listeners
-*/
-@Override
-public void addTrainingListeners(Collection<TrainingListener> listeners) {
-trainingListeners.addAll(listeners);
-}

@AllArgsConstructor
@Data

@@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;


import lombok.val;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -131,7 +130,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer {

val weightsShape = new long[]{outputDepth, inputDepth, kernel[0], kernel[1], kernel[2]};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();

@@ -180,7 +180,7 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer {

val weightsShape = new long[] {outputDepth, inputDepth, kernel[0], kernel[1]};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView);
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else {
int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights(

@@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params;


import lombok.val;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -130,7 +129,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer
//libnd4j: [kD, kH, kW, oC, iC]
val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView);
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView);
} else {
int[] kernel = layerConf.getKernelSize();
return WeightInitUtil.reshapeWeights(

@@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params;

import lombok.val;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -61,7 +60,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer {

val weightsShape = new long[] {inputDepth, outputDepth, kernel[0], kernel[1]};

-INDArray weights = layerConf.getWeightInitFn().init(
+INDArray weights = layerConf.getWeightInit().init(
fanIn, fanOut, weightsShape, 'c', weightView);

return weights;

@@ -196,13 +196,13 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;

if (initializeParameters) {
-if( layerConf.getWeightInitFn() == null) {
+if( layerConf.getWeightInit() == null) {
// set a default and set warning
-layerConf.setWeightInitFn(new WeightInitXavier());
+layerConf.setWeightInit(new WeightInitXavier());
log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(),
conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName());
}
-return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInitFn(),
+return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
weightParamView, true);
} else {
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false);

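Side note on the hunk above: the initializer now falls back to Xavier when no weight init was set on a layer. The fragment below is a minimal, hypothetical sketch of that fallback using only the renamed getWeightInit()/IWeightInit API visible in this diff; the class name, fan-in/fan-out values and the flat parameter view are illustrative assumptions, not part of the commit.

import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class WeightInitFallbackSketch {
    public static void main(String[] args) {
        long fanIn = 6, fanOut = 20;
        // Flat parameter view that a param initializer would normally hand in (illustrative only).
        INDArray weightView = Nd4j.create(1, fanIn * fanOut);
        // Pretend layerConf.getWeightInit() returned null, as handled in the hunk above.
        IWeightInit weightInit = null;
        if (weightInit == null) {
            weightInit = new WeightInitXavier(); // same default the initializer falls back to
        }
        INDArray weights = weightInit.init(fanIn, fanOut, new long[]{fanIn, fanOut},
                IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightView);
        System.out.println(java.util.Arrays.toString(weights.shape()));
    }
}
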
@@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -193,7 +191,7 @@ public class DepthwiseConvolutionParamInitializer extends AbstractParamInitializ

val weightsShape = new long[] {kernel[0], kernel[1], inputDepth, depthMultiplier};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();

@@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -159,14 +157,14 @@ public class GravesBidirectionalLSTMParamInitializer extends AbstractParamInitia
val inputWShape = new long[]{nLast, 4 * nL};
val recurrentWShape = new long[]{nL, 4 * nL + 3};

-params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
+params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwF));
-params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape,
+params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwF));
params.put(BIAS_KEY_FORWARDS, bF);
-params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
+params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwR));
-params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape,
+params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwR));
params.put(BIAS_KEY_BACKWARDS, bR);
} else {

@@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -124,10 +122,10 @@ public class GravesLSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent();
} else {
-rwInit = layerConf.getWeightInitFn();
+rwInit = layerConf.getWeightInit();
}

-params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
+params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));

@@ -27,7 +27,6 @@ import java.util.List;
import java.util.Map;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;

@@ -132,10 +131,10 @@ public class LSTMParamInitializer extends AbstractParamInitializer {
if(layerConf.getWeightInitFnRecurrent() != null){
rwInit = layerConf.getWeightInitFnRecurrent();
} else {
-rwInit = layerConf.getWeightInitFn();
+rwInit = layerConf.getWeightInit();
}

-params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape,
+params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape,
IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView));
params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView));
biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)},

@@ -133,7 +133,7 @@ public class PReLUParamInitializer extends AbstractParamInitializer {

PReLULayer layerConf = (PReLULayer) conf;
if (initializeParameters) {
-return layerConf.getWeightInitFn().init(layerConf.getNIn(), layerConf.getNOut(),
+return layerConf.getWeightInit().init(layerConf.getNIn(), layerConf.getNOut(),
weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
} else {
return WeightInitUtil.reshapeWeights(weightShape, weightParamView);

@@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.weights.WeightInitUtil;

@@ -220,7 +218,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ

val weightsShape = new long[] {depthMultiplier, inputDepth, kernel[0], kernel[1]};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
int[] kernel = layerConf.getKernelSize();

@@ -249,7 +247,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ

val weightsShape = new long[] {outputDepth, depthMultiplier * inputDepth, 1, 1};

-return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c',
+return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c',
weightView);
} else {
return WeightInitUtil.reshapeWeights(

@@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
-import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.weights.IWeightInit;

@@ -102,14 +100,14 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {

if (initializeParams) {
m = getSubsets(paramsView, nIn, nOut, false, hasLayerNorm(c));
-INDArray w = c.getWeightInitFn().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY));
+INDArray w = c.getWeightInit().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY));
m.put(WEIGHT_KEY, w);

IWeightInit rwInit;
if (c.getWeightInitFnRecurrent() != null) {
rwInit = c.getWeightInitFnRecurrent();
} else {
-rwInit = c.getWeightInitFn();
+rwInit = c.getWeightInit();
}

INDArray rw = rwInit.init(nOut, nOut, new long[]{nOut, nOut}, 'f', m.get(RECURRENT_WEIGHT_KEY));

@@ -21,7 +21,6 @@
package org.deeplearning4j.nn.params;

import lombok.val;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit;

@@ -200,7 +199,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
int[] encoderLayerSizes = layer.getEncoderLayerSizes();
int[] decoderLayerSizes = layer.getDecoderLayerSizes();

-IWeightInit weightInit = layer.getWeightInitFn();
+IWeightInit weightInit = layer.getWeightInit();

int soFar = 0;
for (int i = 0; i < encoderLayerSizes.length; i++) {

@@ -164,7 +164,7 @@ public class FineTuneConfiguration {
bl.setActivationFn(activationFn);
}
if (weightInitFn != null) {
-bl.setWeightInitFn(weightInitFn);
+bl.setWeightInit(weightInitFn);
}
if (biasInit != null) {
bl.setBiasInit(biasInit);

@@ -264,10 +264,10 @@ public class FineTuneConfiguration {
NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = NeuralNetConfiguration.builder();

if (activationFn != null) {
-confBuilder.activationFn(activationFn);
+confBuilder.activation(activationFn);
}
if (weightInitFn != null) {
-confBuilder.weightInitFn(weightInitFn);
+confBuilder.weightInit(weightInitFn);
}
if (biasInit != null) {
confBuilder.biasInit(biasInit);

@@ -462,7 +462,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(init);
+layerImplF.setWeightInit(init);
layerImplF.setNIn(nIn);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);

@@ -480,7 +480,7 @@ public class TransferLearning {
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(scheme);
+layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams);

@@ -492,7 +492,7 @@ public class TransferLearning {
layerImpl = layerConf; //modify in place
if(layerImpl instanceof FeedForwardLayer) {
layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(schemeNext);
+layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut);
numParams = layerImpl.initializer().numParams(layerConf);
if (numParams > 0) {

@@ -738,7 +738,7 @@ public class TransferLearning {

layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(scheme);
+layerImplF.setWeightInit(scheme);
layerImplF.setNIn(nIn);

if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex

@@ -767,7 +767,7 @@ public class TransferLearning {
LayerConfiguration layerImpl = layerConf.clone();
layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(scheme);
+layerImplF.setWeightInit(scheme);
layerImplF.setNOut(nOut);

if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex

@@ -806,7 +806,7 @@ public class TransferLearning {
continue;
layerImpl = layerConf.clone();
layerImplF = (FeedForwardLayer) layerImpl;
-layerImplF.setWeightInitFn(schemeNext);
+layerImplF.setWeightInit(schemeNext);
layerImplF.setNIn(nOut);

nInFromNewConfig.put(fanoutVertexName, nOut);

@@ -207,10 +207,11 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
*/
public void setStateViewArray(INDArray viewArray) {
if(this.updaterStateViewArray == null){
-if(viewArray == null)
+if(viewArray == null || viewArray.length()==0)
return; //No op - for example, SGD and NoOp updater - i.e., no stored state
else {
-throw new IllegalStateException("Attempting to set updater state view array with null value");
+//this.updaterStateViewArray.set
+// throw new IllegalStateException("Attempting to set updater state view array with null value");
}
}
if (this.updaterStateViewArray.length() != viewArray.length())

@@ -296,7 +297,7 @@ public abstract class BaseMultiLayerUpdater<T extends IModel> implements Updater
//PRE apply (gradient clipping, etc): done on a per-layer basis
for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
String layerName = entry.getKey();
-ITrainableLayer layer = layersByName.get(layerName);
+ITrainableLayer layer = layersByName.get(layerName); //Todo Layers may have the same name!?

preApply(layer, layerGradients.get(layerName), iteration);
}

@@ -29,7 +29,6 @@ import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;

@@ -39,7 +38,6 @@ import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.junit.jupiter.api.Test;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.learning.config.Adam;

@@ -85,8 +83,8 @@ class dnnTest {
.updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.gradientNormalizationThreshold(100)
-.weightInitFn(new WeightInitXavier())
+.weightInit(new WeightInitXavier())
-.activationFn(new ActivationSigmoid())
+.activation(new ActivationSigmoid())
// .inputType(InputType.convolutional(28, 28, 1))
.layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())

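For readers following the rename in this commit, the sketch below shows roughly how the new builder methods exercised by dnnTest would be used end to end; the standalone class name, the trailing build() call, and the layer sizes are illustrative assumptions mirroring the hunk above, not a definitive API reference.

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.learning.config.Adam;

public class RenamedBuilderSketch {
    public static void main(String[] args) {
        // Hypothetical usage of the renamed builder methods from this commit:
        // weightInit(...) replaces weightInitFn(...), activation(...) replaces activationFn(...).
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(100)
                .weightInit(new WeightInitXavier())
                .activation(new ActivationSigmoid())
                .layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
                .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
                .build();
        System.out.println(conf);
    }
}
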
@@ -1182,7 +1182,7 @@ public class TrainModule implements UIModule {
String.valueOf(nParams)});
if (nParams > 0) {
try {
-String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn());
+String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInit());
layerInfoRows.add(new String[]{
i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str});
} catch (JsonProcessingException e) {

@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;

@@ -176,7 +176,7 @@ public class ResNet50 extends ZooModel {
.activation(Activation.IDENTITY)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(updater)
-.weightInitFn(weightInit)
+.weightInit(weightInit)
.l1(1e-7)
.l2(5e-5)
.miniBatch(true)