Fixing tests

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-05-08 09:22:38 +02:00
parent 35ea21e436
commit 871073e4a4
60 changed files with 306 additions and 338 deletions

View File

@@ -217,14 +217,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
 MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
 netCopy.fit(data);
-IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getIUpdater();
-double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getLearningRate();
-double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getMomentum();
-IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater();
+IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getUpdater();
+double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getUpdater()).getLearningRate();
+double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getUpdater()).getMomentum();
+IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater();
 sparkNet.fit(sparkData);
-double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getLearningRate();
-double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getMomentum();
+double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater()).getLearningRate();
+double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater()).getMomentum();
 assertEquals(expectedUpdater, actualUpdater);
 assertEquals(expectedLR, actualLR, 0.01);
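The accessor rename applied throughout this commit (getIUpdater() -> getUpdater()) is read the same way everywhere; a minimal sketch reusing the calls shown in these hunks, with made-up hyper-parameter values (illustrative only, not part of the commit):

    // Sketch - assumes the fork's NeuralNetConfiguration builder API and the renamed getUpdater() accessor
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .updater(new Nesterovs(0.1, 0.9))
            .layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
            .build();
    IUpdater u = ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater();
    double lr = ((Nesterovs) u).getLearningRate();    // 0.1
    double momentum = ((Nesterovs) u).getMomentum();  // 0.9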

View File

@@ -47,6 +47,7 @@ import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
@@ -77,10 +78,10 @@ public class App {
 private static final double LEARNING_RATE = 0.000002;
 private static final double GRADIENT_THRESHOLD = 100.0;
-private static final int X_DIM = 28;
-private static final int Y_DIM = 28;
+private static final int X_DIM = 20;
+private static final int Y_DIM = 20;
 private static final int CHANNELS = 1;
-private static final int batchSize = 9;
+private static final int batchSize = 10;
 private static final int INPUT = 128;
 private static final int OUTPUT_PER_PANEL = 4;
@@ -97,12 +98,13 @@ public class App {
 return new LayerConfiguration[] {
 DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
 ActivationLayer.builder(Activation.LEAKYRELU).build(),
 DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
-DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
-.build()
+DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
 };
 }
@@ -131,30 +133,34 @@ public class App {
 private static LayerConfiguration[] disLayers() {
 return new LayerConfiguration[]{
-DenseLayer.builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network
+DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 DropoutLayer.builder(1 - 0.5).build(),
-DenseLayer.builder().nIn(X_DIM * Y_DIM*CHANNELS*2).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
+DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 DropoutLayer.builder(1 - 0.5).build(),
-DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
+DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 DropoutLayer.builder(1 - 0.5).build(),
-DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
+DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
 ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
 DropoutLayer.builder(1 - 0.5).build(),
-OutputLayer.builder().lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
+OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
 };
 }
 private static NeuralNetConfiguration discriminator() {
-NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+NeuralNetConfiguration conf =
+NeuralNetConfiguration.builder()
 .seed(42)
 .updater(UPDATER)
 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
 .weightInit(WeightInit.XAVIER)
+//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
+.weightNoise(null)
 // .weightInitFn(new WeightInitXavier())
 // .activationFn(new ActivationIdentity())
 .activation(Activation.IDENTITY)
@@ -171,7 +177,7 @@ public class App {
 LayerConfiguration[] disLayers = Arrays.stream(disLayers())
 .map((layer) -> {
 if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
-return FrozenLayerWithBackprop.builder(layer);
+return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
 } else {
 return layer;
 }
@@ -204,7 +210,7 @@ public class App {
 public static void main(String... args) throws Exception {
 log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
-Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
+Nd4j.getMemoryManager().setAutoGcWindow(500);
 // MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
 // FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
@@ -236,10 +242,10 @@ public class App {
 copyParams(gen, dis, gan);
-gen.addTrainingListeners(new PerformanceListener(10, true));
-dis.addTrainingListeners(new PerformanceListener(10, true));
-gan.addTrainingListeners(new PerformanceListener(10, true));
-gan.addTrainingListeners(new ScoreToChartListener("gan"));
+gen.addTrainingListeners(new PerformanceListener(15, true));
+//dis.addTrainingListeners(new PerformanceListener(10, true));
+//gan.addTrainingListeners(new PerformanceListener(10, true));
+//gan.addTrainingListeners(new ScoreToChartListener("gan"));
 //dis.setListeners(new ScoreToChartListener("dis"));
 System.out.println(gan.toString());
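With the new constants above (X_DIM = Y_DIM = 20, CHANNELS = 1), the layer widths driven by X_DIM*Y_DIM*CHANNELS shrink from 784 to 400 units, so the generator's final TANH layer and the renamed "1.Dense" discriminator layer now both work on 400-unit vectors. A quick sanity check of the sizes (illustrative only, not part of the commit):

    int xDim = 20, yDim = 20, channels = 1;
    int genOut = xDim * yDim * channels;   // 400 - generator output width and "1.Dense" width
    int disHidden = genOut * 4;            // 1600 - "2.Dense" output / "3.Dense" input width
    assert genOut == 400 && disHidden == 1600;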

View File

@@ -107,6 +107,9 @@ public class DeallocatorService {
 boolean canRun = true;
 long cnt = 0;
 while (canRun) {
+log.trace("Starting deallocator threat with name '{}'. isPeriodicGc: {}, AutoGcWindow: {}. Current allocated memory: {}"
+    ,this.getName(), Nd4j.getMemoryManager().isPeriodicGcActive()
+    , Nd4j.getMemoryManager().getAutoGcWindow(), Nd4j.getMemoryManager().allocatedMemory(deviceId));
 // if periodicGc is enabled, only first thread will call for it
 if (Nd4j.getMemoryManager().isPeriodicGcActive() && threadIdx == 0 && Nd4j.getMemoryManager().getAutoGcWindow() > 0) {
 val reference = (DeallocatableReference) queue.poll();
@@ -120,6 +123,7 @@ public class DeallocatorService {
 }
 }
 } else {
 // invoking deallocator
+log.trace("Deallocate reference {}", reference.getId());
 reference.getDeallocator().deallocate();
 referenceMap.remove(reference.getId());
 }
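The new trace statements report the ND4J memory-manager settings; a sketch of the calls that feed them (assuming the standard Nd4j MemoryManager API already used in this commit, values illustrative):

    // Illustrative: these are the values the new log.trace(...) call reports
    Nd4j.getMemoryManager().togglePeriodicGc(true);               // -> isPeriodicGcActive()
    Nd4j.getMemoryManager().setAutoGcWindow(500);                 // -> getAutoGcWindow(), as set in App.main above
    long allocated = Nd4j.getMemoryManager().allocatedMemory(0);  // -> allocatedMemory(deviceId)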

View File

@@ -498,7 +498,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
 BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
 assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
-assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater());
+assertEquals(bl.getUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getUpdater());
 }
 @Test

View File

@@ -306,7 +306,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
 .activation(Activation.IDENTITY)
 .updater(new NoOp())
 .weightInit(WeightInit.XAVIER)
-.list()
 .layer(LSTM.builder().nOut(layerSize).build())
 .layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
 .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
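The dropped .list() call matches the builder style used in the other hunks of this commit: layers are added directly to NeuralNetConfiguration.builder(). A minimal sketch of the resulting pattern (illustrative; layerSize is defined by the test):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .activation(Activation.IDENTITY)
            .updater(new NoOp())
            .weightInit(WeightInit.XAVIER)
            .layer(LSTM.builder().nOut(layerSize).build())
            .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
            .build();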

View File

@@ -76,7 +76,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
 assertEquals(act, layer.getActivationFn());
 assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
 assertEquals(new Dropout(dropOut), layer.getDropOut());
-assertEquals(updater, layer.getIUpdater());
+assertEquals(updater, layer.getUpdater());
 assertEquals(gradNorm, layer.getGradientNormalization());
 assertEquals(gradNormThreshold, layer.getGradientNormalizationThreshold(), 0.0);
 }

View File

@@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
-assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
 Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
 testMomentumAfter2.put(0, 0.2);
@@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();
-assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
 }
 @Test
@@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
-assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
-assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
-assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater() instanceof AdaDelta);
+assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater() instanceof AdaDelta);
+assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getRho(), 0.0);
+assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getRho(), 0.0);
 conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
 .layer(0, DenseLayer.builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
@@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();
-assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
-assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
-assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater() instanceof RmsProp);
+assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater() instanceof AdaDelta);
+assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getRmsDecay(), 0.0);
+assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getRho(), 0.0);
 }
@@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
-assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
-assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
-assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
-assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
+assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getBeta1(), 0.0);
+assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getBeta1(), 0.0);
+assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getBeta2(), 0.0);
+assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getBeta2(), 0.0);
 }
 @Test
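The hunks above all exercise the same precedence rule: a layer-level updater overrides the network default set on the builder. A consolidated sketch of that rule, using the RmsProp values visible in the test (my own illustrative layer layout, not the test's exact configuration):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))          // network default
            .layer(0, DenseLayer.builder().nIn(2).nOut(2)
                    .updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON))  // per-layer override
                    .build())
            .layer(1, DenseLayer.builder().nIn(2).nOut(2).build())                    // inherits the default
            .build();
    // conf.getConf(0) then reports the overriding updater via the renamed getUpdater() accessor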

View File

@@ -163,12 +163,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 net.init();
 BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
-assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
+assertEquals(expectedMomentum, ((Nesterovs) layerConf.getUpdater()).getMomentum(), 1e-3);
 assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
 assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
 BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
-assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);
+assertEquals(0.4, ((Nesterovs) layerConf1.getUpdater()).getMomentum(), 1e-3);
 // Adam Updater
 conf = NeuralNetConfiguration.builder().updater(new Adam(0.3))
@@ -183,8 +183,8 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
 layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
-assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
-assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
+assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getUpdater()).getBeta1(), 1e-3);
+assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getUpdater()).getBeta2(), 1e-3);
 assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
 assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
 assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));
@@ -197,12 +197,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 net.init();
 layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
-assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
+assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getUpdater()).getRmsDecay(), 1e-3);
 assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
 assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
 layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
-assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);
+assertEquals(0.4, ((RmsProp) layerConf1.getUpdater()).getRmsDecay(), 1e-3);
 }

View File

@@ -164,16 +164,4 @@ public class SameDiffConv extends SameDiffLayer {
 return activation.asSameDiff("out", sameDiff, conv);
 }
-@Override
-public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
-NeuralNetConfiguration clone = globalConfig.clone().build();
-if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
-}
-if (convolutionMode == null) {
-convolutionMode = clone.getConvolutionMode();
-}
-}
 }

View File

@@ -114,14 +114,6 @@ public class SameDiffDense extends SameDiffLayer {
 return activation.asSameDiff("out", sd, z);
 }
-@Override
-public void applyGlobalConfigToLayer(
-NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
-NeuralNetConfiguration clone = globalConfig.clone().build();
-if (activation == null) {
-activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
-}
-}
 public char paramReshapeOrder(String param) {
 // To match DL4J for easy comparison

View File

@@ -21,7 +21,6 @@
 package org.deeplearning4j.nn.layers.samediff.testlayers;
 import java.util.Map;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer;
@@ -87,8 +86,4 @@ public class SameDiffMSEOutputLayer extends SameDiffOutputLayer {
 // To match DL4J for easy comparison
 return 'f';
 }
-@Override
-public void applyGlobalConfigToLayer(
-NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {}
 }

View File

@@ -99,7 +99,7 @@ public class TransferLearningComplex extends BaseDL4JTest {
 //Also check config:
 BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
-assertEquals(new Adam(2e-2), bl.getIUpdater());
+assertEquals(new Adam(2e-2), bl.getUpdater());
 assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
 }
 assertTrue(cFound);

View File

@@ -92,7 +92,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
 BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
-assertEquals(new RmsProp(0.5), bl.getIUpdater());
+assertEquals(new RmsProp(0.5), bl.getUpdater());
 }
@@ -504,13 +504,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 //Check original net isn't modified:
 BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
-assertEquals(new Adam(1e-4), l0.getIUpdater());
+assertEquals(new Adam(1e-4), l0.getUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
-assertEquals(new Adam(1e-4), l1.getIUpdater());
+assertEquals(new Adam(1e-4), l1.getUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
 assertEquals(new WeightInitRelu(), l1.getWeightInit());
 assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
@@ -519,13 +519,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 //Check new net has only the appropriate things modified (i.e., LR)
 l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
-assertEquals(new Adam(2e-2), l0.getIUpdater());
+assertEquals(new Adam(2e-2), l0.getUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
-assertEquals(new Adam(2e-2), l1.getIUpdater());
+assertEquals(new Adam(2e-2), l1.getUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
 assertEquals(new WeightInitRelu(), l1.getWeightInit());
 assertEquals(0.2, TestUtils.getL2(l1), 1e-6);

View File

@@ -100,7 +100,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -145,7 +145,7 @@ public class TestUpdaters extends BaseDL4JTest {
 msdx.put(key, msdxTmp);
 count++;
 }
-assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4);
+assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getUpdater()).getRho(), 1e-4);
 }
 assertEquals(4, count);
@@ -166,7 +166,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -186,7 +186,7 @@ public class TestUpdaters extends BaseDL4JTest {
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 count++;
 }
-assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
 assertEquals(2, count);
 }
@@ -210,7 +210,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -246,8 +246,8 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }
-assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
-assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
+assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getUpdater()).getBeta1(), 1e-4);
+assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getUpdater()).getBeta2(), 1e-4);
 assertEquals(2, count);
 }
@@ -274,7 +274,7 @@ public class TestUpdaters extends BaseDL4JTest {
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -363,7 +363,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -399,8 +399,8 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }
-assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
-assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
+assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getUpdater()).getBeta1(), 1e-4);
+assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getUpdater()).getBeta2(), 1e-4);
 assertEquals(2, count);
 }
@@ -419,7 +419,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -444,7 +444,7 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }
-assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4);
+assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getUpdater()).getMomentum(), 1e-4);
 assertEquals(2, count);
 }
@@ -466,7 +466,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);
@@ -496,7 +496,7 @@ public class TestUpdaters extends BaseDL4JTest {
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 lastG.put(key, lastGTmp);
 }
-assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4);
+assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getUpdater()).getRmsDecay(), 1e-4);
 }
 @Test
@@ -528,7 +528,7 @@ public class TestUpdaters extends BaseDL4JTest {
 gradExpected = val.mul(lr);
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 }
-assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
 }
@@ -770,7 +770,7 @@ public class TestUpdaters extends BaseDL4JTest {
 gradExpected = val.mul(lr);
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 }
-assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
 //Test with pretrain == false
@@ -798,7 +798,7 @@ public class TestUpdaters extends BaseDL4JTest {
 layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 updater = UpdaterCreator.getUpdater(layer);
-assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
 }
 @Test
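All of the TestUpdaters hunks touch the same call chain; the recurring setup, written out once for reference (sketch only; conf, params, gradients and numParams come from each individual test):

    BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
    layer.setBackpropGradientsViewArray(gradients);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // getUpdater() (formerly getIUpdater()) sizes the updater state for this layer's parameters
    int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);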

View File

@@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest {
 .build();
 //First: Check updater config
-assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
-assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
-assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
-assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);
+assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getUpdater() instanceof CustomIUpdater);
+assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getUpdater() instanceof CustomIUpdater);
+assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getUpdater() instanceof Sgd);
+assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getUpdater() instanceof Sgd);
-CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater();
-CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater();
+CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getUpdater();
+CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getUpdater();
 assertEquals(lr, u0_0.getLearningRate(), 1e-6);
 assertEquals(lr, u0_1.getLearningRate(), 1e-6);
-Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater();
-Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater();
+Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getUpdater();
+Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getUpdater();
 assertEquals(lr, u1_0.getLearningRate(), 1e-6);
 assertEquals(lr, u1_1.getLearningRate(), 1e-6);

View File

@@ -73,8 +73,8 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
 assertEquals(new WeightInitXavier(), l0.getWeightInit());
-assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
-assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
+assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
 OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
 assertEquals("softmax", l1.getActivationFn().toString());
@@ -82,9 +82,9 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
 assertEquals(new WeightInitXavier(), l1.getWeightInit());
-assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
-assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
-assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new Nesterovs(0.15, 0.9), l1.getUpdater());
+assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
+assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
 int numParams = (int)net.numParams();
 assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
@@ -107,8 +107,8 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
 assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l0.getDropOut());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
@@ -119,8 +119,8 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
 assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
-assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
+assertEquals(0.15, ((RmsProp)l1.getUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l1.getDropOut());
 assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
 assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
@@ -146,8 +146,8 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
 assertArrayEquals(new int[] {1, 1}, l0.getStride());
 assertArrayEquals(new int[] {0, 0}, l0.getPadding());
@@ -166,8 +166,8 @@ public class RegressionTest050 extends BaseDL4JTest {
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 int numParams = (int)net.numParams();
 assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());

View File

@@ -75,8 +75,8 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
 assertEquals(new WeightInitXavier(), l0.getWeightInit());
-assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
-assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
+assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
 OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
 assertEquals("softmax", l1.getActivationFn().toString());
@@ -84,9 +84,9 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
 assertEquals(new WeightInitXavier(), l1.getWeightInit());
-assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
-assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
-assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new Nesterovs(0.15, 0.9), l1.getUpdater());
+assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
+assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
 int numParams = (int)net.numParams();
 assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
@@ -109,8 +109,8 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(4, l0.getNOut());
 assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l0.getDropOut());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
@@ -123,8 +123,8 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(4, l1.getNIn());
 assertEquals(5, l1.getNOut());
 assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
-assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
+assertEquals(0.15, ((RmsProp)l1.getUpdater()).getLearningRate(), 1e-6);
 assertEquals(new Dropout(0.6), l1.getDropOut());
 assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
 assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
@@ -152,8 +152,8 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(3, l0.getNIn());
 assertEquals(3, l0.getNOut());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
 assertArrayEquals(new int[] {1, 1}, l0.getStride());
 assertArrayEquals(new int[] {0, 0}, l0.getPadding());
@@ -172,8 +172,8 @@ public class RegressionTest060 extends BaseDL4JTest {
 assertEquals(26 * 26 * 3, l2.getNIn());
 assertEquals(5, l2.getNOut());
 assertEquals(new WeightInitRelu(), l0.getWeightInit());
-assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
-assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
+assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
+assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
 assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);

View File

@ -76,8 +76,8 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer(); OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
assertEquals("softmax", l1.getActivationFn().toString()); assertEquals("softmax", l1.getActivationFn().toString());
@ -85,9 +85,9 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
long numParams = (int)net.numParams(); long numParams = (int)net.numParams();
assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams()); assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
@ -110,8 +110,8 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getDropOut()); assertEquals(new Dropout(0.6), l0.getDropOut());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
@ -124,8 +124,8 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getDropOut()); assertEquals(new Dropout(0.6), l1.getDropOut());
assertEquals(0.1, TestUtils.getL1(l1), 1e-6); assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
@ -153,8 +153,8 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
@ -173,8 +173,8 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
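Taken together, the assertions above show the pattern this commit migrates to: a layer configuration's gradient updater is read through getUpdater() (previously getIUpdater()) and then cast to its concrete type. A minimal sketch of that read-back pattern, reusing the Nesterovs values from these tests (variable names here are illustrative only):

    BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
    IUpdater u = layerConf.getUpdater();                        // was getIUpdater() before this commit
    assertTrue(u instanceof Nesterovs);
    assertEquals(0.15, ((Nesterovs) u).getLearningRate(), 1e-6);
    assertEquals(0.9, ((Nesterovs) u).getMomentum(), 1e-6);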


@ -75,10 +75,10 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof Nesterovs); assertTrue(l0.getUpdater() instanceof Nesterovs);
Nesterovs n = (Nesterovs) l0.getIUpdater(); Nesterovs n = (Nesterovs) l0.getUpdater();
assertEquals(0.9, n.getMomentum(), 1e-6); assertEquals(0.9, n.getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
assertEquals(0.15, n.getLearningRate(), 1e-6); assertEquals(0.15, n.getLearningRate(), 1e-6);
@ -88,9 +88,9 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof Nesterovs); assertTrue(l1.getUpdater() instanceof Nesterovs);
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
assertEquals(0.15, n.getLearningRate(), 1e-6); assertEquals(0.15, n.getLearningRate(), 1e-6);
int numParams = (int)net.numParams(); int numParams = (int)net.numParams();
@ -114,11 +114,11 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp); assertTrue(l0.getUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater(); RmsProp r = (RmsProp) l0.getUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
assertEquals(0.15, r.getLearningRate(), 1e-6); assertEquals(0.15, r.getLearningRate(), 1e-6);
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l0.getDropOut()); assertEquals(new Dropout(0.6), l0.getDropOut());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
@ -131,11 +131,11 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNIn());
assertEquals(5, l1.getNOut()); assertEquals(5, l1.getNOut());
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit()); assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
assertTrue(l1.getIUpdater() instanceof RmsProp); assertTrue(l1.getUpdater() instanceof RmsProp);
r = (RmsProp) l1.getIUpdater(); r = (RmsProp) l1.getUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
assertEquals(0.15, r.getLearningRate(), 1e-6); assertEquals(0.15, r.getLearningRate(), 1e-6);
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertEquals(new Dropout(0.6), l1.getDropOut()); assertEquals(new Dropout(0.6), l1.getDropOut());
assertEquals(0.1, TestUtils.getL1(l1), 1e-6); assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
@ -163,11 +163,11 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNIn());
assertEquals(3, l0.getNOut()); assertEquals(3, l0.getNOut());
assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new WeightInitRelu(), l0.getWeightInit());
assertTrue(l0.getIUpdater() instanceof RmsProp); assertTrue(l0.getUpdater() instanceof RmsProp);
RmsProp r = (RmsProp) l0.getIUpdater(); RmsProp r = (RmsProp) l0.getUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
assertEquals(0.15, r.getLearningRate(), 1e-6); assertEquals(0.15, r.getLearningRate(), 1e-6);
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
@ -186,8 +186,8 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(26 * 26 * 3, l2.getNIn());
assertEquals(5, l2.getNOut()); assertEquals(5, l2.getNOut());
assertEquals(new WeightInitRelu(), l2.getWeightInit()); assertEquals(new WeightInitRelu(), l2.getWeightInit());
assertTrue(l2.getIUpdater() instanceof RmsProp); assertTrue(l2.getUpdater() instanceof RmsProp);
r = (RmsProp) l2.getIUpdater(); r = (RmsProp) l2.getUpdater();
assertEquals(0.96, r.getRmsDecay(), 1e-6); assertEquals(0.96, r.getRmsDecay(), 1e-6);
assertEquals(0.15, r.getLearningRate(), 1e-6); assertEquals(0.15, r.getLearningRate(), 1e-6);


@ -91,21 +91,21 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater()); assertEquals(new RmsProp(0.1), l0.getUpdater());
GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration(); GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new RmsProp(0.1), l1.getIUpdater()); assertEquals(new RmsProp(0.1), l1.getUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater()); assertEquals(new RmsProp(0.1), l0.getUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
@ -141,7 +141,7 @@ public class RegressionTest100a extends BaseDL4JTest {
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.05), l0.getIUpdater()); assertEquals(new Adam(0.05), l0.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100a/VaeMNISTAnomaly_Output_100a.bin"); File f2 = Resources.asFile("regression_testing/100a/VaeMNISTAnomaly_Output_100a.bin");


@ -75,12 +75,12 @@ public class RegressionTest100b3 extends BaseDL4JTest {
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.95), l0.getIUpdater()); assertEquals(new RmsProp(0.95), l0.getUpdater());
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
assertEquals(new RmsProp(0.95), l1.getIUpdater()); assertEquals(new RmsProp(0.95), l1.getUpdater());
INDArray outExp; INDArray outExp;
@ -126,21 +126,21 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
@ -176,7 +176,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b3/VaeMNISTAnomaly_Output_100b3.bin"); File f2 = Resources.asFile("regression_testing/100b3/VaeMNISTAnomaly_Output_100b3.bin");


@ -94,12 +94,12 @@ public class RegressionTest100b4 extends BaseDL4JTest {
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
assertEquals(new RmsProp(0.95), l0.getIUpdater()); assertEquals(new RmsProp(0.95), l0.getUpdater());
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
assertEquals(new RmsProp(0.95), l1.getIUpdater()); assertEquals(new RmsProp(0.95), l1.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources File f2 = Resources
@ -144,21 +144,21 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
@ -194,7 +194,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin"); File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin");
@ -262,7 +262,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
assertArrayEquals(new int[]{2, 1}, l0.getStride()); assertArrayEquals(new int[]{2, 1}, l0.getStride());
assertArrayEquals(new int[]{1, 1}, l0.getDilation()); assertArrayEquals(new int[]{1, 1}, l0.getDilation());
@ -273,7 +273,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(8, l1.getNOut()); assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l1.getStride()); assertArrayEquals(new int[]{1, 1}, l1.getStride());
assertArrayEquals(new int[]{1, 1}, l1.getDilation()); assertArrayEquals(new int[]{1, 1}, l1.getDilation());
@ -299,7 +299,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(16, l5.getNOut()); assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInit()); assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater()); assertEquals(new Adam(0.005), l5.getUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l5.getStride()); assertArrayEquals(new int[]{1, 1}, l5.getStride());
assertArrayEquals(new int[]{1, 1}, l5.getDilation()); assertArrayEquals(new int[]{1, 1}, l5.getDilation());
@ -320,7 +320,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(4, l8.getNOut()); assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInit()); assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater()); assertEquals(new Adam(0.005), l8.getUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l8.getStride()); assertArrayEquals(new int[]{1, 1}, l8.getStride());
assertArrayEquals(new int[]{1, 1}, l8.getDilation()); assertArrayEquals(new int[]{1, 1}, l8.getDilation());
@ -329,7 +329,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInit()); assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new Adam(0.005), l9.getUpdater());
assertEquals(new LossMAE(), l9.getLossFunction()); assertEquals(new LossMAE(), l9.getLossFunction());
INDArray outExp; INDArray outExp;


@ -76,12 +76,12 @@ public class RegressionTest100b6 extends BaseDL4JTest {
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new ActivationTanH(), l0.getActivationFn());
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
assertEquals(new RmsProp(0.95), l0.getIUpdater()); assertEquals(new RmsProp(0.95), l0.getUpdater());
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
assertEquals(new RmsProp(0.95), l1.getIUpdater()); assertEquals(new RmsProp(0.95), l1.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources File f2 = Resources
@ -126,21 +126,21 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(200, l0.getNOut()); assertEquals(200, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationTanH(), l1.getActivationFn());
assertEquals(200, l1.getNOut()); assertEquals(200, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getUpdater());
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(new ActivationSoftmax(), l2.getActivationFn());
assertEquals(77, l2.getNOut()); assertEquals(77, l2.getNOut());
assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightInitXavier(), l2.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
@ -176,7 +176,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(1e-3), l0.getIUpdater()); assertEquals(new Adam(1e-3), l0.getUpdater());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin"); File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
@ -242,7 +242,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(4, l0.getNOut()); assertEquals(4, l0.getNOut());
assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightInitXavier(), l0.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getUpdater());
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
assertArrayEquals(new int[]{2, 1}, l0.getStride()); assertArrayEquals(new int[]{2, 1}, l0.getStride());
assertArrayEquals(new int[]{1, 1}, l0.getDilation()); assertArrayEquals(new int[]{1, 1}, l0.getDilation());
@ -253,7 +253,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(8, l1.getNOut()); assertEquals(8, l1.getNOut());
assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightInitXavier(), l1.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
assertEquals(new Adam(0.005), l1.getIUpdater()); assertEquals(new Adam(0.005), l1.getUpdater());
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l1.getStride()); assertArrayEquals(new int[]{1, 1}, l1.getStride());
assertArrayEquals(new int[]{1, 1}, l1.getDilation()); assertArrayEquals(new int[]{1, 1}, l1.getDilation());
@ -279,7 +279,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(16, l5.getNOut()); assertEquals(16, l5.getNOut());
assertEquals(new WeightInitXavier(), l5.getWeightInit()); assertEquals(new WeightInitXavier(), l5.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
assertEquals(new Adam(0.005), l5.getIUpdater()); assertEquals(new Adam(0.005), l5.getUpdater());
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l5.getStride()); assertArrayEquals(new int[]{1, 1}, l5.getStride());
assertArrayEquals(new int[]{1, 1}, l5.getDilation()); assertArrayEquals(new int[]{1, 1}, l5.getDilation());
@ -300,7 +300,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(4, l8.getNOut()); assertEquals(4, l8.getNOut());
assertEquals(new WeightInitXavier(), l8.getWeightInit()); assertEquals(new WeightInitXavier(), l8.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
assertEquals(new Adam(0.005), l8.getIUpdater()); assertEquals(new Adam(0.005), l8.getUpdater());
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
assertArrayEquals(new int[]{1, 1}, l8.getStride()); assertArrayEquals(new int[]{1, 1}, l8.getStride());
assertArrayEquals(new int[]{1, 1}, l8.getDilation()); assertArrayEquals(new int[]{1, 1}, l8.getDilation());
@ -309,7 +309,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
assertEquals(new WeightInitXavier(), l9.getWeightInit()); assertEquals(new WeightInitXavier(), l9.getWeightInit());
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new Adam(0.005), l9.getUpdater());
assertEquals(new LossMAE(), l9.getLossFunction()); assertEquals(new LossMAE(), l9.getLossFunction());
INDArray outExp; INDArray outExp;


@ -113,7 +113,7 @@ public class CustomLayer extends FeedForwardLayer {
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
int updaterStateSize = (int) getIUpdater().stateSize(numParams); int updaterStateSize = (int) getUpdater().stateSize(numParams);
int trainSizeFixed = 0; int trainSizeFixed = 0;
int trainSizeVariable = 0; int trainSizeVariable = 0;


@ -235,7 +235,7 @@ public class GradientCheckUtil {
for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) { for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) {
if (n instanceof BaseLayerConfiguration) { if (n instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) n; BaseLayerConfiguration bl = (BaseLayerConfiguration) n;
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u instanceof Sgd) { if (u instanceof Sgd) {
// Must have LR of 1.0 // Must have LR of 1.0
double lr = ((Sgd) u).getLearningRate(); double lr = ((Sgd) u).getLearningRate();
@ -540,7 +540,7 @@ public class GradientCheckUtil {
if (lv.getLayerConfiguration() instanceof BaseLayerConfiguration) { if (lv.getLayerConfiguration() instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) lv.getLayerConfiguration(); BaseLayerConfiguration bl = (BaseLayerConfiguration) lv.getLayerConfiguration();
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u instanceof Sgd) { if (u instanceof Sgd) {
// Must have LR of 1.0 // Must have LR of 1.0
double lr = ((Sgd) u).getLearningRate(); double lr = ((Sgd) u).getLearningRate();
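As the comment above notes, GradientCheckUtil only accepts layers whose updater is plain SGD with a learning rate of exactly 1.0, so the updater cannot rescale the analytic gradients. A hedged sketch of a layer set up to pass that check (the DenseLayer builder calls are assumptions, used purely for illustration):

    DenseLayer layer = DenseLayer.builder().nIn(4).nOut(3).build();
    layer.setUpdater(new Sgd(1.0));   // any other updater, or any other learning rate, is rejected by the check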


@ -322,7 +322,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
@Getter @Setter @lombok.Builder.Default @JsonIgnore private IUpdater iUpdater = new Sgd(); @Getter @Setter @lombok.Builder.Default @JsonIgnore private IUpdater iUpdater = new Sgd();
/** /**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #setIUpdater(IUpdater)}<br> * set by {@link #setUpdater(IUpdater)}<br>
* Note: values set by this method will be applied to all applicable layers in the network, unless * Note: values set by this method will be applied to all applicable layers in the network, unless
* a different value is explicitly set on a given layer. In other words: values set via this * a different value is explicitly set on a given layer. In other words: values set via this
* method are used as the default value, and can be overridden on a per-layer basis. * method are used as the default value, and can be overridden on a per-layer basis.
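The javadoc above describes how the bias-only updater relates to the general one: it is optional, applies network-wide, and falls back to the regular updater when unset. A short sketch of that behaviour, assuming the builder exposes updater(...) and biasUpdater(...) setters for the two fields shown here:

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .updater(new Adam(1e-3))      // default updater for all parameters in all layers
            .biasUpdater(new Sgd(1e-2))   // biases only; omit this to fall back to updater(...)
            .build();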


@ -65,7 +65,7 @@ public class ActivationLayer extends NoParamLayer {
} }
@Override @Override
public IUpdater getIUpdater() { public IUpdater getUpdater() {
return null; return null;
} }


@ -88,7 +88,7 @@ public class AutoEncoder extends BasePretrainNetwork {
val actElementsPerEx = outputType.arrayElementsPerExample() + inputType.arrayElementsPerExample(); val actElementsPerEx = outputType.arrayElementsPerExample() + inputType.arrayElementsPerExample();
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
int trainSizePerEx = 0; int trainSizePerEx = 0;
if (getDropOut() != null) { if (getDropOut() != null) {


@ -95,7 +95,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link * Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
* org.nd4j.linalg.learning.config.Nesterovs} * org.nd4j.linalg.learning.config.Nesterovs}
*/ */
@Getter @Setter @Getter
protected IUpdater updater; protected IUpdater updater;
/** /**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
@ -134,7 +134,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
*/ */
public void resetLayerDefaultConfig() { public void resetLayerDefaultConfig() {
// clear the learning related params for all layers in the origConf and set to defaults // clear the learning related params for all layers in the origConf and set to defaults
this.setIUpdater(null); this.setUpdater( (IUpdater) null);
this.setWeightInit(null); this.setWeightInit(null);
this.setBiasInit(Double.NaN); this.setBiasInit(Double.NaN);
this.setGainInit(Double.NaN); this.setGainInit(Double.NaN);
@ -142,10 +142,16 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
this.regularizationBias = null; this.regularizationBias = null;
this.setGradientNormalization(GradientNormalization.None); this.setGradientNormalization(GradientNormalization.None);
this.setGradientNormalizationThreshold(1.0); this.setGradientNormalizationThreshold(1.0);
this.updater = null;
this.biasUpdater = null; this.biasUpdater = null;
} }
public void setUpdater(Updater updater) {
setUpdater(updater.getIUpdaterWithDefaultConfig());
}
public void setUpdater(IUpdater iUpdater) {
this.updater=iUpdater;
}
@Override @Override
public BaseLayerConfiguration clone() { public BaseLayerConfiguration clone() {
BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone(); BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
@ -203,6 +209,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
if (this.updater == null) this.updater = conf.getUpdater(); if (this.updater == null) this.updater = conf.getUpdater();
if (this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias(); if (this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
if (this.regularization == null) this.regularization = conf.getRegularization(); if (this.regularization == null) this.regularization = conf.getRegularization();
if( this.weightInit == null) this.weightInit = conf.getWeightInit();
if (this.gradientNormalization == null) if (this.gradientNormalization == null)
this.gradientNormalization = conf.getGradientNormalization(); this.gradientNormalization = conf.getGradientNormalization();
// if(this.weightInit == null) this.weightInit = conf.getWeightInit(); // if(this.weightInit == null) this.weightInit = conf.getWeightInit();
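The two setUpdater overloads added above accept either the legacy Updater enum or an IUpdater instance, and runInheritance(...) later fills in any value still unset from the network configuration. A hedged sketch of both call forms (the DenseLayer builder is used only as an example of a BaseLayerConfiguration subclass):

    BaseLayerConfiguration dense = DenseLayer.builder().nIn(4).nOut(5).build();
    dense.setUpdater(Updater.NESTEROVS);   // enum form, expanded via getIUpdaterWithDefaultConfig()
    dense.setUpdater(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON));   // explicit IUpdater form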


@ -56,7 +56,7 @@ public abstract class BaseOutputLayer extends FeedForwardLayer {
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
int trainSizeFixed = 0; int trainSizeFixed = 0;
int trainSizeVariable = 0; int trainSizeVariable = 0;


@ -235,7 +235,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
val paramSize = initializer().numParams(this); val paramSize = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(paramSize); val updaterStateSize = (int) getUpdater().stateSize(paramSize);
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
InputType.InputTypeConvolutional outputType = InputType.InputTypeConvolutional outputType =


@ -60,7 +60,7 @@ public class DenseLayer extends FeedForwardLayer {
LayerValidation.assertNInNOutSet( LayerValidation.assertNInNOutSet(
"DenseLayerConfiguration", getName(), layerIndex, getNIn(), getNOut()); "DenseLayerConfiguration", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.runInheritance(); runInheritance();
org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
@ -84,7 +84,7 @@ public class DenseLayer extends FeedForwardLayer {
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
int trainSizeFixed = 0; int trainSizeFixed = 0;
int trainSizeVariable = 0; int trainSizeVariable = 0;


@ -96,7 +96,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
val actElementsPerEx = outputType.arrayElementsPerExample(); val actElementsPerEx = outputType.arrayElementsPerExample();
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
// Embedding layer does not use caching. // Embedding layer does not use caching.
// Inference: no working memory - just activations (pullRows) // Inference: no working memory - just activations (pullRows)


@ -162,7 +162,7 @@ extends FeedForwardLayerBuilder<C, B> {
val actElementsPerEx = outputType.arrayElementsPerExample(); val actElementsPerEx = outputType.arrayElementsPerExample();
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
return new LayerMemoryReport.Builder(name, EmbeddingSequenceLayer.class, inputType, outputType) return new LayerMemoryReport.Builder(name, EmbeddingSequenceLayer.class, inputType, outputType)
.standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx) .standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx)


@ -34,6 +34,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -281,16 +282,6 @@ public abstract class LayerConfiguration
"Not supported: all layers with parameters should override this method"); "Not supported: all layers with parameters should override this method");
} }
public IUpdater getIUpdater() {
throw new UnsupportedOperationException(
"Not supported: all layers with parameters should override this method");
}
public void setIUpdater(IUpdater iUpdater) {
log.warn(
"Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), name);
}
/** /**
* This is a report of the estimated memory consumption for the given layer * This is a report of the estimated memory consumption for the given layer
* *
@ -316,6 +307,7 @@ public abstract class LayerConfiguration
if (this.activation == null) this.activation = conf.getActivation(); if (this.activation == null) this.activation = conf.getActivation();
if (this.dropOut == null) this.dropOut = conf.getIdropOut(); if (this.dropOut == null) this.dropOut = conf.getIdropOut();
if (this.weightNoise == null) this.weightNoise = conf.getWeightNoise(); if (this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
} }
/** /**
@ -326,6 +318,24 @@ public abstract class LayerConfiguration
runInheritance(getNetConfiguration()); runInheritance(getNetConfiguration());
} }
/**
* This will always return a no-op updater.
* @return
*/
public IUpdater getUpdater() {
log.warn("Calling getUpdater() in {} will always return no-Op Updater.", LayerConfiguration.class.getSimpleName());
return Updater.NONE.getIUpdaterWithDefaultConfig();
}
@Deprecated
public void setUpdater(Updater updater) {
setUpdater(updater.getIUpdaterWithDefaultConfig());
}
public void setUpdater(IUpdater iUpdater) {
throw new RuntimeException("When " + this.getName() + " wants to use an Updater, it needs to override the "
+ "Getter/Setter for the Updater and not rely on the LayerConfiguration class.");
}
public abstract static class LayerConfigurationBuilder< public abstract static class LayerConfigurationBuilder<
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> { C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
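With the defaults introduced above, a configuration that does not override these accessors now gets a warning plus the no-op updater from getUpdater(), and setUpdater(...) fails outright. A minimal sketch of the resulting behaviour (layerConf stands for any LayerConfiguration subclass without its own override; this fragment is illustrative, not part of the commit):

    IUpdater u = layerConf.getUpdater();       // logs a warning, returns Updater.NONE.getIUpdaterWithDefaultConfig()
    try {
        layerConf.setUpdater(new Adam(1e-3));  // throws: layers that own parameters must override the setter
    } catch (RuntimeException e) {
        // expected for configurations that do not manage an updater themselves
    }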


@ -249,17 +249,6 @@ public class LocallyConnected1D extends SameDiffLayer {
} }
} }
@Override
public void applyGlobalConfigToLayer(
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration global_conf = globalConfig.build();
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation());
}
if (convolutionMode == null) {
convolutionMode = global_conf.getConvolutionMode();
}
}
private static final class LocallyConnected1DBuilderImpl private static final class LocallyConnected1DBuilderImpl
extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> { extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> {


@ -305,17 +305,6 @@ public class LocallyConnected2D extends SameDiffLayer {
} }
} }
@Override
public void applyGlobalConfigToLayer(
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
NeuralNetConfiguration gconf = globalConfig.build();
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation());
}
if (convolutionMode == null) {
convolutionMode = gconf.getConvolutionMode();
}
}
private static final class LocallyConnected2DBuilderImpl private static final class LocallyConnected2DBuilderImpl
extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> { extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> {


@ -56,10 +56,11 @@ public abstract class NoParamLayer extends LayerConfiguration {
} }
/** /**
* Will always return a no-op updater.
* @return * @return
*/ */
@Override @Override
public IUpdater getIUpdater() { public IUpdater getUpdater() {
return Updater.NONE.getIUpdaterWithDefaultConfig(); return Updater.NONE.getIUpdaterWithDefaultConfig();
} }


@ -40,9 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@SuperBuilder(builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class OutputLayer extends BaseOutputLayer { public class OutputLayer extends BaseOutputLayer {
{ // Set default activation function to softmax (to match default loss function MCXENT)
setActivation(Activation.SOFTMAX.getActivationFunction());
}
public static OutputLayerBuilder<?, ?> builder() { public static OutputLayerBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
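With the instance initializer removed above, softmax is no longer forced onto every OutputLayer, so the activation has to come from the builder (or from inherited defaults). A hedged example of setting it explicitly so the layer still matches the default MCXENT loss (builder method names follow the pattern used elsewhere in this commit and are otherwise assumptions):

    OutputLayer out = OutputLayer.builder()
            .nIn(10).nOut(3)
            .activation(Activation.SOFTMAX)                      // previously applied implicitly
            .lossFunction(LossFunctions.LossFunction.MCXENT)
            .build();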


@ -120,7 +120,7 @@ public class PReLULayer extends BaseLayerConfiguration {
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
return new LayerMemoryReport.Builder(name, PReLULayer.class, inputType, outputType) return new LayerMemoryReport.Builder(name, PReLULayer.class, inputType, outputType)
.standardMemory(numParams, updaterStateSize) .standardMemory(numParams, updaterStateSize)
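The memory-report hunks in this commit all derive the updater contribution from IUpdater.stateSize(numParams). A small illustration of what that number means for two common updaters (a sketch, not taken from the commit):

    long numParams = 1_000;
    long adamState = new Adam(1e-3).stateSize(numParams);   // 2 * numParams: Adam keeps m and v per parameter
    long sgdState  = new Sgd(0.1).stateSize(numParams);     // 0: plain SGD keeps no per-parameter state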


@ -22,6 +22,7 @@ package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
@ -35,6 +36,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -44,7 +46,9 @@ import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
public class RecurrentAttentionLayer extends SameDiffLayer { public class RecurrentAttentionLayer extends SameDiffLayer implements ITraininableLayerConfiguration {
private DataType dataType;
private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> { private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> {
public RecurrentAttentionLayer build() { public RecurrentAttentionLayer build() {
@ -190,13 +194,6 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
} }
} }
@Override
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation());
}
}
@Override @Override
public void validateInput(INDArray input) { public void validateInput(INDArray input) {
final long inputLength = input.size(2); final long inputLength = input.size(2);


@ -91,7 +91,7 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getUpdater().stateSize(numParams);
int trainSizeFixed = 0; int trainSizeFixed = 0;
int trainSizeVariable = 0; int trainSizeVariable = 0;


@ -82,6 +82,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration {
org.deeplearning4j.nn.api.Layer newUnderlyingLayer = underlying.instantiate(conf, trainingListeners, org.deeplearning4j.nn.api.Layer newUnderlyingLayer = underlying.instantiate(conf, trainingListeners,
layerIndex, layerParamsView, initializeParams, networkDataType); layerIndex, layerParamsView, initializeParams, networkDataType);
runInheritance();
newUnderlyingLayer.setLayerConfiguration(underlying); //Fix a problem, where the embedded layer gets the conf of the frozen layer, rather than its own newUnderlyingLayer.setLayerConfiguration(underlying); //Fix a problem, where the embedded layer gets the conf of the frozen layer, rather than its own
NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration(); NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration();
@ -130,4 +131,6 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration {
this.constraints = constraints; this.constraints = constraints;
this.underlying.setConstraints(constraints); this.underlying.setConstraints(constraints);
} }
} }


@ -38,6 +38,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil; import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.NetworkUtils; import org.deeplearning4j.util.NetworkUtils;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -93,6 +94,21 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
@Getter @Setter @Getter @Setter
private SDLayerParams layerParams; private SDLayerParams layerParams;
@Override
public void runInheritance(@NotNull NeuralNetConfiguration conf) {
super.runInheritance(conf);
if (this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater();
if (this.updater == null) this.updater = conf.getUpdater();
if (this.regularizationBias == null || regularizationBias.isEmpty()) this.regularizationBias = conf.getRegularizationBias();
if (this.regularization == null || regularization.isEmpty()) this.regularization = conf.getRegularization();
// if( this.weightInit == null) this.weightInit = conf.getWeightInit();
if (this.gradientNormalization == null)
this.gradientNormalization = conf.getGradientNormalization();
// if(this.weightInit == null) this.weightInit = conf.getWeightInit();
if (Double.isNaN(gradientNormalizationThreshold)) {
this.gradientNormalizationThreshold = conf.getGradientNormalizationThreshold();
}
}
@Override @Override
public List<Regularization> getRegularizationByParam(String paramName) { public List<Regularization> getRegularizationByParam(String paramName) {
if (layerParams.isWeightParam(paramName)) { if (layerParams.isWeightParam(paramName)) {
@ -122,10 +138,6 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
return null; return null;
} }
public void applyGlobalConfigToLayer(
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
// Default implementation: no op
}
/** /**
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String, * Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
@ -195,29 +207,6 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array); fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
} }
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
NeuralNetConfiguration bConf = b.build();
if (regularization == null || regularization.isEmpty()) {
regularization = bConf.getRegularization();
}
if (regularizationBias == null || regularizationBias.isEmpty()) {
regularizationBias = bConf.getRegularizationBias();
}
if (updater == null) {
updater = bConf.getUpdater();
}
if (biasUpdater == null) {
biasUpdater = bConf.getBiasUpdater();
}
if (gradientNormalization == null) {
gradientNormalization = bConf.getGradientNormalization();
}
if (Double.isNaN(gradientNormalizationThreshold)) {
gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold();
}
applyGlobalConfigToLayer(b);
}
/** /**
* This method generates an "all ones" mask array for use in the SameDiff model when none is * This method generates an "all ones" mask array for use in the SameDiff model when none is


@ -137,7 +137,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
val actElementsPerEx = outputType.arrayElementsPerExample(); val actElementsPerEx = outputType.arrayElementsPerExample();
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
int updaterStateSize = (int) getIUpdater().stateSize(numParams); int updaterStateSize = (int) getUpdater().stateSize(numParams);
int inferenceWorkingMemSizePerEx = 0; int inferenceWorkingMemSizePerEx = 0;
// Forward pass size through the encoder: // Forward pass size through the encoder:


@ -123,17 +123,17 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
/** /**
* @return * @return
*/ */
@Override
public IUpdater getIUpdater() { public IUpdater getUpdater() {
return underlying.getIUpdater(); return underlying.getUpdater();
} }
/** /**
* @param iUpdater * @param iUpdater
*/ */
@Override
public void setIUpdater(IUpdater iUpdater) { public void setUpdater(IUpdater iUpdater) {
underlying.setIUpdater(iUpdater); underlying.setUpdater(iUpdater);
} }
@Override @Override
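The wrapper configuration above simply forwards the renamed accessors to the wrapped layer, so updater settings read and write the underlying configuration. A hedged sketch using FrozenLayerWithBackprop (which appears earlier in this commit) as the wrapper; the constructor form and builder calls are assumptions:

    DenseLayer inner = DenseLayer.builder().nIn(4).nOut(4).build();
    BaseWrapperLayerConfiguration frozen = new FrozenLayerWithBackprop(inner);
    frozen.setUpdater(new Sgd(1e-2));     // delegated to inner.setUpdater(...)
    IUpdater same = frozen.getUpdater();  // reads back the same value as inner.getUpdater()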


@ -68,7 +68,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
for(LayerConfiguration l : layers){ for(LayerConfiguration l : layers){
if(l instanceof BaseLayerConfiguration){ if(l instanceof BaseLayerConfiguration){
BaseLayerConfiguration bl = (BaseLayerConfiguration)l; BaseLayerConfiguration bl = (BaseLayerConfiguration)l;
if(bl.getIUpdater() == null && bl.initializer().numParams(bl) > 0){ if(bl.getUpdater() == null && bl.initializer().numParams(bl) > 0){
return true; return true;
} }
} }
@ -200,7 +200,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
break; break;
} }
layer.setIUpdater(iu); layer.setUpdater(iu);
} }
} }
} }

View File

@ -116,7 +116,7 @@ public class ComputationGraphConfigurationDeserializer
} }
if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayerConfiguration if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[layerIdx]).getIUpdater() == null){ && ((BaseLayerConfiguration)layers[layerIdx]).getUpdater() == null){
handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
} }

View File

@ -87,7 +87,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
ObjectNode on = (ObjectNode) confsNode.get(i); ObjectNode on = (ObjectNode) confsNode.get(i);
ObjectNode confNode = null; ObjectNode confNode = null;
if(layers[i] instanceof BaseLayerConfiguration if(layers[i] instanceof BaseLayerConfiguration
&& ((BaseLayerConfiguration)layers[i]).getIUpdater() == null){ && ((BaseLayerConfiguration)layers[i]).getUpdater() == null){
//layer -> (first/only child) -> updater //layer -> (first/only child) -> updater
if(on.has("layer")){ if(on.has("layer")){
confNode = on; confNode = on;

View File

@ -77,6 +77,7 @@ public class WeightNoise implements IWeightNoise {
(applyToBias && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) { (applyToBias && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) {
org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution); org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution);
INDArray noise = dist.sample(param.ulike()); INDArray noise = dist.sample(param.ulike());
INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering());

View File

@ -577,7 +577,8 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
*/ */
@Override @Override
public Map<String, INDArray> getParamTable(boolean isBackprop) { public Map<String, INDArray> getParamTable(boolean isBackprop) {
throw new RuntimeException("Not implemented"); // throw new RuntimeException("Not implemented");
return null;
} }
/** /**

View File

@ -340,6 +340,6 @@ public class GlobalPoolingLayer extends AbstractLayer<org.deeplearning4j.nn.conf
*/ */
@Override @Override
public void setParamTable(Map<String, INDArray> paramTable) { public void setParamTable(Map<String, INDArray> paramTable) {
throw new RuntimeException("Not implemented.");
} }
} }

View File

@ -768,7 +768,7 @@ public class LSTMHelpers {
InputType outputType = lstmLayer.getOutputType(-1, inputType); InputType outputType = lstmLayer.getOutputType(-1, inputType);
val numParams = lstmLayer.initializer().numParams(lstmLayer); val numParams = lstmLayer.initializer().numParams(lstmLayer);
int updaterSize = (int) lstmLayer.getIUpdater().stateSize(numParams); int updaterSize = (int) lstmLayer.getUpdater().stateSize(numParams);
//Memory use during forward pass: //Memory use during forward pass:
//ifogActivations: nTimeSteps * [minibatch,4*layerSize] (not cached during inference fwd pass) //ifogActivations: nTimeSteps * [minibatch,4*layerSize] (not cached during inference fwd pass)

View File

@ -565,7 +565,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
* inference/backprop are excluded from the returned list. * inference/backprop are excluded from the returned list.
* *
* @param backpropParamsOnly If true, return backprop params only. If false: return all params * @param backpropParamsOnly If true, return backprop params only. If false: return all params
* @return Parameters for the network * @return Parameters for the network, empty Map if no parameters present in the neural network
*/ */
@Override @Override
public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) { public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) {
@ -573,10 +573,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
Map<String, INDArray> allParams = new LinkedHashMap<>(); Map<String, INDArray> allParams = new LinkedHashMap<>();
for (int i = 0; i < layers.length; i++) { for (int i = 0; i < layers.length; i++) {
Map<String, INDArray> paramMap = layers[i].getParamTable(backpropParamsOnly); Map<String, INDArray> paramMap = layers[i].getParamTable(backpropParamsOnly);
if(paramMap!=null){
for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) { for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
String newKey = i + "_" + entry.getKey(); String newKey = i + "_" + entry.getKey();
allParams.put(newKey, entry.getValue()); allParams.put(newKey, entry.getValue());
} }}
} }
return allParams; return allParams;
} }
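With getParamTable(boolean) now allowed to return null for parameter-less layers (see the AbstractLayer change above) instead of throwing, aggregating callers need the null guard added here. A short sketch of the same guarded aggregation in a hypothetical helper; the "layerIndex_paramName" key format mirrors the loop above:

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import org.nd4j.linalg.api.ndarray.INDArray;

    final class ParamTables {
        static Map<String, INDArray> collect(List<Map<String, INDArray>> perLayer) {
            Map<String, INDArray> all = new LinkedHashMap<>();
            for (int i = 0; i < perLayer.size(); i++) {
                Map<String, INDArray> paramMap = perLayer.get(i);
                if (paramMap == null) {
                    continue;                    // layers without parameters are skipped
                }
                for (Map.Entry<String, INDArray> e : paramMap.entrySet()) {
                    all.put(i + "_" + e.getKey(), e.getValue());
                }
            }
            return all;
        }
    }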

View File

@ -94,6 +94,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer)) if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer))
throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName()); throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName());
INDArray reshapedParamsView = paramsView.reshape(paramsView.length());
Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>()); Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
val length = numParams(conf); val length = numParams(conf);
@ -107,14 +108,14 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
val nOut = layerConf.getNOut(); val nOut = layerConf.getNOut();
val nWeightParams = nIn * nOut; val nWeightParams = nIn * nOut;
INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); INDArray weightView = reshapedParamsView.get(NDArrayIndex.interval(0, nWeightParams));
params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams));
layerConf.addVariable(WEIGHT_KEY); layerConf.addVariable(WEIGHT_KEY);
long offset = nWeightParams; long offset = nWeightParams;
if(hasBias(layerConf)){ if(hasBias(layerConf)){
INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), INDArray biasView = reshapedParamsView.get(
NDArrayIndex.interval(offset, offset + nOut)); NDArrayIndex.interval(offset, offset + nOut));
params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams)); params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams));
layerConf.addVariable(BIAS_KEY); layerConf.addVariable(BIAS_KEY);
@ -122,7 +123,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
} }
if(hasLayerNorm(layerConf)){ if(hasLayerNorm(layerConf)){
INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true), INDArray gainView = reshapedParamsView.get(
NDArrayIndex.interval(offset, offset + nOut)); NDArrayIndex.interval(offset, offset + nOut));
params.put(GAIN_KEY, createGain(conf, gainView, initializeParams)); params.put(GAIN_KEY, createGain(conf, gainView, initializeParams));
conf.addVariable(GAIN_KEY); conf.addVariable(GAIN_KEY);
@ -138,8 +139,9 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
val nIn = layerConf.getNIn(); val nIn = layerConf.getNIn();
val nOut = layerConf.getNOut(); val nOut = layerConf.getNOut();
val nWeightParams = nIn * nOut; val nWeightParams = nIn * nOut;
INDArray gradientViewReshaped = gradientView.reshape(gradientView.length());
INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)) INDArray weightGradientView = gradientViewReshaped.get(NDArrayIndex.interval(0, nWeightParams))
.reshape('f', nIn, nOut); .reshape('f', nIn, nOut);
Map<String, INDArray> out = new LinkedHashMap<>(); Map<String, INDArray> out = new LinkedHashMap<>();
@ -147,14 +149,14 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
long offset = nWeightParams; long offset = nWeightParams;
if(hasBias(layerConf)){ if(hasBias(layerConf)){
INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true), INDArray biasView = gradientViewReshaped.get(
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
out.put(BIAS_KEY, biasView); out.put(BIAS_KEY, biasView);
offset += nOut; offset += nOut;
} }
if(hasLayerNorm(layerConf)){ if(hasLayerNorm(layerConf)){
INDArray gainView = gradientView.get(NDArrayIndex.interval(0,0,true), INDArray gainView = gradientViewReshaped.get(
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
out.put(GAIN_KEY, gainView); out.put(GAIN_KEY, gainView);
} }
@ -196,12 +198,6 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;
if (initializeParameters) { if (initializeParameters) {
if( layerConf.getWeightInit() == null) {
// set a default and set warning
layerConf.setWeightInit(new WeightInitXavier());
log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getName(),
conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName());
}
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(), return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
weightParamView, true); weightParamView, true);
} else { } else {
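The initializer changes above reshape the flat parameter and gradient views to rank 1 before slicing, so a single NDArrayIndex.interval(...) selects each contiguous block (weights, then bias, then gain). A standalone sketch of that slicing pattern with made-up sizes (nIn=3, nOut=2), not the initializer itself:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class FlatViewSlicing {
        public static void main(String[] args) {
            int nIn = 3, nOut = 2;
            int nWeightParams = nIn * nOut;
            INDArray paramsView = Nd4j.rand(1, nWeightParams + nOut);   // flat [1, nWeights + nOut] view

            INDArray flat = paramsView.reshape(paramsView.length());    // rank-1 view over the same data
            INDArray weightView = flat.get(NDArrayIndex.interval(0, nWeightParams));
            INDArray biasView = flat.get(NDArrayIndex.interval(nWeightParams, nWeightParams + nOut));

            System.out.println(weightView.reshape('f', nIn, nOut));     // column-major weight matrix
            System.out.println(biasView);
        }
    }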

View File

@ -200,7 +200,7 @@ public class FineTuneConfiguration {
bl.setGradientNormalizationThreshold(gradientNormalizationThreshold); bl.setGradientNormalizationThreshold(gradientNormalizationThreshold);
} }
if (updater != null) { if (updater != null) {
bl.setIUpdater(updater); bl.setUpdater(updater);
} }
if (biasUpdater != null) { if (biasUpdater != null) {
bl.setBiasUpdater(biasUpdater); bl.setBiasUpdater(biasUpdater);

View File

@ -125,7 +125,7 @@ public class NetworkUtils {
LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration(); LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration();
if (l instanceof BaseLayerConfiguration) { if (l instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) l; BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u != null && u.hasLearningRate()) { if (u != null && u.hasLearningRate()) {
if (newLrSchedule != null) { if (newLrSchedule != null) {
u.setLrAndSchedule(Double.NaN, newLrSchedule); u.setLrAndSchedule(Double.NaN, newLrSchedule);
@ -207,7 +207,7 @@ public class NetworkUtils {
int epoch = net.getEpochCount(); int epoch = net.getEpochCount();
if (l instanceof BaseLayerConfiguration) { if (l instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) l; BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u != null && u.hasLearningRate()) { if (u != null && u.hasLearningRate()) {
double d = u.getLearningRate(iter, epoch); double d = u.getLearningRate(iter, epoch);
if (Double.isNaN(d)) { if (Double.isNaN(d)) {
@ -247,7 +247,7 @@ public class NetworkUtils {
LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration(); LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration();
if (l instanceof BaseLayerConfiguration) { if (l instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) l; BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u != null && u.hasLearningRate()) { if (u != null && u.hasLearningRate()) {
if (newLrSchedule != null) { if (newLrSchedule != null) {
u.setLrAndSchedule(Double.NaN, newLrSchedule); u.setLrAndSchedule(Double.NaN, newLrSchedule);
@ -329,7 +329,7 @@ public class NetworkUtils {
int epoch = net.getComputationGraphConfiguration().getEpochCount(); int epoch = net.getComputationGraphConfiguration().getEpochCount();
if (l instanceof BaseLayerConfiguration) { if (l instanceof BaseLayerConfiguration) {
BaseLayerConfiguration bl = (BaseLayerConfiguration) l; BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
if (u != null && u.hasLearningRate()) { if (u != null && u.hasLearningRate()) {
double d = u.getLearningRate(iter, epoch); double d = u.getLearningRate(iter, epoch);
if (Double.isNaN(d)) { if (Double.isNaN(d)) {
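The NetworkUtils hunks above read the effective learning rate through the layer's updater and check the result for NaN before using it. A hedged helper sketch using only the calls visible in the hunk; the helper itself is hypothetical:

    import org.nd4j.linalg.learning.config.IUpdater;

    final class LearningRates {
        /** Effective LR for an updater at a given iteration/epoch, or null if none is configured. */
        static Double currentLr(IUpdater u, int iter, int epoch) {
            if (u != null && u.hasLearningRate()) {
                double d = u.getLearningRate(iter, epoch);
                return Double.isNaN(d) ? null : d;   // NaN: no single fixed rate at this point
            }
            return null;
        }
    }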

View File

@ -210,14 +210,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
netCopy.fit(data); netCopy.fit(data);
IUpdater expectedUpdater = ((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater(); IUpdater expectedUpdater = ((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater();
double expectedLR = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater()).getLearningRate(); double expectedLR = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater()).getLearningRate();
double expectedMomentum = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater()).getMomentum(); double expectedMomentum = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater()).getMomentum();
IUpdater actualUpdater = ((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater(); IUpdater actualUpdater = ((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater();
sparkNet.fit(sparkData); sparkNet.fit(sparkData);
double actualLR = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getLearningRate(); double actualLR = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater()).getLearningRate();
double actualMomentum = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getMomentum(); double actualMomentum = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater()).getMomentum();
assertEquals(expectedUpdater, actualUpdater); assertEquals(expectedUpdater, actualUpdater);
assertEquals(expectedLR, actualLR, 0.01); assertEquals(expectedLR, actualLR, 0.01);

View File

@ -1580,10 +1580,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context)); AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));
} }
int errorCode = nativeOps.lastErrorCode();
if (nativeOps.lastErrorCode() != 0) if (errorCode != 0) {
throw new RuntimeException(nativeOps.lastErrorMessage() + " error code: " + nativeOps.lastErrorCode()); throw new RuntimeException(
nativeOps.lastErrorMessage() + " error code: " + errorCode);
}
profilingConfigurableHookOut(op, oc, st); profilingConfigurableHookOut(op, oc, st);
return z; return z;
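The change above reads nativeOps.lastErrorCode() once into a local, so the zero-check and the exception message are guaranteed to report the same code. A minimal sketch of that read-once pattern; NativeStatus is a hypothetical stand-in for the real NativeOps binding:

    interface NativeStatus {
        int lastErrorCode();
        String lastErrorMessage();
    }

    final class ErrorCheck {
        static void throwIfFailed(NativeStatus ops) {
            int errorCode = ops.lastErrorCode();   // read exactly once
            if (errorCode != 0) {
                throw new RuntimeException(ops.lastErrorMessage() + " error code: " + errorCode);
            }
        }
    }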

View File

@ -1189,7 +1189,7 @@ public class TrainModule implements UIModule {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
IUpdater u = bl.getIUpdater(); IUpdater u = bl.getUpdater();
String us = (u == null ? "" : u.getClass().getSimpleName()); String us = (u == null ? "" : u.getClass().getSimpleName());
layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerUpdater"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerUpdater"),
us}); us});