parent
35ea21e436
commit
871073e4a4
|
@ -217,14 +217,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
||||||
|
|
||||||
netCopy.fit(data);
|
netCopy.fit(data);
|
||||||
IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getIUpdater();
|
IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getUpdater();
|
||||||
double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getLearningRate();
|
double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getUpdater()).getLearningRate();
|
||||||
double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getMomentum();
|
double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getUpdater()).getMomentum();
|
||||||
|
|
||||||
IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater();
|
IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater();
|
||||||
sparkNet.fit(sparkData);
|
sparkNet.fit(sparkData);
|
||||||
double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getLearningRate();
|
double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater()).getLearningRate();
|
||||||
double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getMomentum();
|
double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getUpdater()).getMomentum();
|
||||||
|
|
||||||
assertEquals(expectedUpdater, actualUpdater);
|
assertEquals(expectedUpdater, actualUpdater);
|
||||||
assertEquals(expectedLR, actualLR, 0.01);
|
assertEquals(expectedLR, actualLR, 0.01);
|
||||||
|
|
|
@ -47,6 +47,7 @@ import org.datavec.image.transform.ShowImageTransform;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||||
|
@ -77,10 +78,10 @@ public class App {
|
||||||
private static final double LEARNING_RATE = 0.000002;
|
private static final double LEARNING_RATE = 0.000002;
|
||||||
private static final double GRADIENT_THRESHOLD = 100.0;
|
private static final double GRADIENT_THRESHOLD = 100.0;
|
||||||
|
|
||||||
private static final int X_DIM = 28;
|
private static final int X_DIM = 20 ;
|
||||||
private static final int Y_DIM = 28;
|
private static final int Y_DIM = 20;
|
||||||
private static final int CHANNELS = 1;
|
private static final int CHANNELS = 1;
|
||||||
private static final int batchSize = 9;
|
private static final int batchSize = 10;
|
||||||
private static final int INPUT = 128;
|
private static final int INPUT = 128;
|
||||||
|
|
||||||
private static final int OUTPUT_PER_PANEL = 4;
|
private static final int OUTPUT_PER_PANEL = 4;
|
||||||
|
@ -97,12 +98,13 @@ public class App {
|
||||||
return new LayerConfiguration[] {
|
return new LayerConfiguration[] {
|
||||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
||||||
|
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
|
|
||||||
.build()
|
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH).build()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,36 +133,40 @@ public class App {
|
||||||
|
|
||||||
private static LayerConfiguration[] disLayers() {
|
private static LayerConfiguration[] disLayers() {
|
||||||
return new LayerConfiguration[]{
|
return new LayerConfiguration[]{
|
||||||
DenseLayer.builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network
|
DenseLayer.builder().name("1.Dense").nOut(X_DIM*Y_DIM*CHANNELS).build(), //input is set by setInputType on the network
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM * Y_DIM*CHANNELS*2).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
DenseLayer.builder().name("2.Dense").nIn(X_DIM * Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
DenseLayer.builder().name("3.Dense").nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
DenseLayer.builder().name("4.Dense").nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DropoutLayer.builder(1 - 0.5).build(),
|
DropoutLayer.builder(1 - 0.5).build(),
|
||||||
OutputLayer.builder().lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
|
||||||
|
OutputLayer.builder().name("dis-output").lossFunction(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build()
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NeuralNetConfiguration discriminator() {
|
private static NeuralNetConfiguration discriminator() {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf =
|
||||||
.seed(42)
|
NeuralNetConfiguration.builder()
|
||||||
.updater(UPDATER)
|
.seed(42)
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
.updater(UPDATER)
|
||||||
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.gradientNormalizationThreshold(GRADIENT_THRESHOLD)
|
||||||
//.weightInitFn(new WeightInitXavier())
|
.weightInit(WeightInit.XAVIER)
|
||||||
//.activationFn(new ActivationIdentity())
|
//.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
|
||||||
.activation(Activation.IDENTITY)
|
.weightNoise(null)
|
||||||
.layersFromArray(disLayers())
|
// .weightInitFn(new WeightInitXavier())
|
||||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
// .activationFn(new ActivationIdentity())
|
||||||
.build();
|
.activation(Activation.IDENTITY)
|
||||||
|
.layersFromArray(disLayers())
|
||||||
|
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||||
|
.build();
|
||||||
((NeuralNetConfiguration) conf).init();
|
((NeuralNetConfiguration) conf).init();
|
||||||
|
|
||||||
return conf;
|
return conf;
|
||||||
|
@ -171,7 +177,7 @@ public class App {
|
||||||
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
LayerConfiguration[] disLayers = Arrays.stream(disLayers())
|
||||||
.map((layer) -> {
|
.map((layer) -> {
|
||||||
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
if (layer instanceof DenseLayer || layer instanceof OutputLayer) {
|
||||||
return FrozenLayerWithBackprop.builder(layer);
|
return FrozenLayerWithBackprop.builder(layer).name("frozen-for-"+layer.getName()).build();
|
||||||
} else {
|
} else {
|
||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
|
@ -204,7 +210,7 @@ public class App {
|
||||||
public static void main(String... args) throws Exception {
|
public static void main(String... args) throws Exception {
|
||||||
|
|
||||||
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m ");
|
||||||
Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
|
Nd4j.getMemoryManager().setAutoGcWindow(500);
|
||||||
|
|
||||||
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45);
|
||||||
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
// FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS());
|
||||||
|
@ -236,10 +242,10 @@ public class App {
|
||||||
|
|
||||||
copyParams(gen, dis, gan);
|
copyParams(gen, dis, gan);
|
||||||
|
|
||||||
gen.addTrainingListeners(new PerformanceListener(10, true));
|
gen.addTrainingListeners(new PerformanceListener(15, true));
|
||||||
dis.addTrainingListeners(new PerformanceListener(10, true));
|
//dis.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
gan.addTrainingListeners(new PerformanceListener(10, true));
|
//gan.addTrainingListeners(new PerformanceListener(10, true));
|
||||||
gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
//gan.addTrainingListeners(new ScoreToChartListener("gan"));
|
||||||
//dis.setListeners(new ScoreToChartListener("dis"));
|
//dis.setListeners(new ScoreToChartListener("dis"));
|
||||||
|
|
||||||
System.out.println(gan.toString());
|
System.out.println(gan.toString());
|
||||||
|
|
|
@ -107,6 +107,9 @@ public class DeallocatorService {
|
||||||
boolean canRun = true;
|
boolean canRun = true;
|
||||||
long cnt = 0;
|
long cnt = 0;
|
||||||
while (canRun) {
|
while (canRun) {
|
||||||
|
log.trace("Starting deallocator threat with name '{}'. isPeriodicGc: {}, AutoGcWindow: {}. Current allocated memory: {}"
|
||||||
|
,this.getName(), Nd4j.getMemoryManager().isPeriodicGcActive()
|
||||||
|
, Nd4j.getMemoryManager().getAutoGcWindow(), Nd4j.getMemoryManager().allocatedMemory(deviceId));
|
||||||
// if periodicGc is enabled, only first thread will call for it
|
// if periodicGc is enabled, only first thread will call for it
|
||||||
if (Nd4j.getMemoryManager().isPeriodicGcActive() && threadIdx == 0 && Nd4j.getMemoryManager().getAutoGcWindow() > 0) {
|
if (Nd4j.getMemoryManager().isPeriodicGcActive() && threadIdx == 0 && Nd4j.getMemoryManager().getAutoGcWindow() > 0) {
|
||||||
val reference = (DeallocatableReference) queue.poll();
|
val reference = (DeallocatableReference) queue.poll();
|
||||||
|
@ -120,6 +123,7 @@ public class DeallocatorService {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// invoking deallocator
|
// invoking deallocator
|
||||||
|
log.trace("Deallocate reference {}", reference.getId());
|
||||||
reference.getDeallocator().deallocate();
|
reference.getDeallocator().deallocate();
|
||||||
referenceMap.remove(reference.getId());
|
referenceMap.remove(reference.getId());
|
||||||
}
|
}
|
||||||
|
|
|
@ -498,7 +498,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
|
assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
|
||||||
assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
|
assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
|
||||||
assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater());
|
assertEquals(bl.getUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getUpdater());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -306,7 +306,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
|
|
|
@ -76,7 +76,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
|
||||||
assertEquals(act, layer.getActivationFn());
|
assertEquals(act, layer.getActivationFn());
|
||||||
assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
|
assertEquals(weight.getWeightInitFunction(), layer.getWeightInit());
|
||||||
assertEquals(new Dropout(dropOut), layer.getDropOut());
|
assertEquals(new Dropout(dropOut), layer.getDropOut());
|
||||||
assertEquals(updater, layer.getIUpdater());
|
assertEquals(updater, layer.getUpdater());
|
||||||
assertEquals(gradNorm, layer.getGradientNormalization());
|
assertEquals(gradNorm, layer.getGradientNormalization());
|
||||||
assertEquals(gradNormThreshold, layer.getGradientNormalizationThreshold(), 0.0);
|
assertEquals(gradNormThreshold, layer.getGradientNormalizationThreshold(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
||||||
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
||||||
|
|
||||||
Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
|
Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
|
||||||
testMomentumAfter2.put(0, 0.2);
|
testMomentumAfter2.put(0, 0.2);
|
||||||
|
@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest {
|
||||||
|
|
||||||
net = new MultiLayerNetwork(conf);
|
net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
||||||
assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
|
assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater() instanceof AdaDelta);
|
||||||
assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
|
assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater() instanceof AdaDelta);
|
||||||
assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
|
assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getRho(), 0.0);
|
||||||
assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
|
assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getRho(), 0.0);
|
||||||
|
|
||||||
conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
|
conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
|
||||||
.layer(0, DenseLayer.builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
|
.layer(0, DenseLayer.builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
|
||||||
|
@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest {
|
||||||
net = new MultiLayerNetwork(conf);
|
net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
|
assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater() instanceof RmsProp);
|
||||||
assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
|
assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater() instanceof AdaDelta);
|
||||||
assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
|
assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getRmsDecay(), 0.0);
|
||||||
assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
|
assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getRho(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
|
assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getBeta1(), 0.0);
|
||||||
assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
|
assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getBeta1(), 0.0);
|
||||||
assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
|
assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getUpdater()).getBeta2(), 0.0);
|
||||||
assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
|
assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getUpdater()).getBeta2(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -163,12 +163,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
|
assertEquals(expectedMomentum, ((Nesterovs) layerConf.getUpdater()).getMomentum(), 1e-3);
|
||||||
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
|
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
|
||||||
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
|
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
|
||||||
|
|
||||||
BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);
|
assertEquals(0.4, ((Nesterovs) layerConf1.getUpdater()).getMomentum(), 1e-3);
|
||||||
|
|
||||||
// Adam Updater
|
// Adam Updater
|
||||||
conf = NeuralNetConfiguration.builder().updater(new Adam(0.3))
|
conf = NeuralNetConfiguration.builder().updater(new Adam(0.3))
|
||||||
|
@ -183,8 +183,8 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
|
||||||
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
|
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
|
||||||
|
|
||||||
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
|
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getUpdater()).getBeta1(), 1e-3);
|
||||||
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
|
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getUpdater()).getBeta2(), 1e-3);
|
||||||
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
|
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit());
|
||||||
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
|
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
|
||||||
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));
|
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));
|
||||||
|
@ -197,12 +197,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
|
assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getUpdater()).getRmsDecay(), 1e-3);
|
||||||
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
|
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
|
||||||
assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
|
assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
|
||||||
|
|
||||||
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);
|
assertEquals(0.4, ((RmsProp) layerConf1.getUpdater()).getRmsDecay(), 1e-3);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,16 +164,4 @@ public class SameDiffConv extends SameDiffLayer {
|
||||||
return activation.asSameDiff("out", sameDiff, conv);
|
return activation.asSameDiff("out", sameDiff, conv);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
NeuralNetConfiguration clone = globalConfig.clone().build();
|
|
||||||
if (activation == null) {
|
|
||||||
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
|
|
||||||
}
|
|
||||||
if (convolutionMode == null) {
|
|
||||||
convolutionMode = clone.getConvolutionMode();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,14 +114,6 @@ public class SameDiffDense extends SameDiffLayer {
|
||||||
return activation.asSameDiff("out", sd, z);
|
return activation.asSameDiff("out", sd, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
NeuralNetConfiguration clone = globalConfig.clone().build();
|
|
||||||
if (activation == null) {
|
|
||||||
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public char paramReshapeOrder(String param) {
|
public char paramReshapeOrder(String param) {
|
||||||
// To match DL4J for easy comparison
|
// To match DL4J for easy comparison
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
|
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer;
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer;
|
||||||
|
@ -87,8 +86,4 @@ public class SameDiffMSEOutputLayer extends SameDiffOutputLayer {
|
||||||
// To match DL4J for easy comparison
|
// To match DL4J for easy comparison
|
||||||
return 'f';
|
return 'f';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ public class TransferLearningComplex extends BaseDL4JTest {
|
||||||
|
|
||||||
//Also check config:
|
//Also check config:
|
||||||
BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
|
BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
|
||||||
assertEquals(new Adam(2e-2), bl.getIUpdater());
|
assertEquals(new Adam(2e-2), bl.getUpdater());
|
||||||
assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
|
assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
|
||||||
}
|
}
|
||||||
assertTrue(cFound);
|
assertTrue(cFound);
|
||||||
|
|
|
@ -92,7 +92,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
|
|
||||||
for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
|
for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
|
||||||
BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
|
BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
|
||||||
assertEquals(new RmsProp(0.5), bl.getIUpdater());
|
assertEquals(new RmsProp(0.5), bl.getUpdater());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -504,13 +504,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
|
|
||||||
//Check original net isn't modified:
|
//Check original net isn't modified:
|
||||||
BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(new Adam(1e-4), l0.getIUpdater());
|
assertEquals(new Adam(1e-4), l0.getUpdater());
|
||||||
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
|
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
|
|
||||||
BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new Adam(1e-4), l1.getIUpdater());
|
assertEquals(new Adam(1e-4), l1.getUpdater());
|
||||||
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
|
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
|
||||||
assertEquals(new WeightInitRelu(), l1.getWeightInit());
|
assertEquals(new WeightInitRelu(), l1.getWeightInit());
|
||||||
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
|
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
|
||||||
|
@ -519,13 +519,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
|
|
||||||
//Check new net has only the appropriate things modified (i.e., LR)
|
//Check new net has only the appropriate things modified (i.e., LR)
|
||||||
l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
|
l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(new Adam(2e-2), l0.getIUpdater());
|
assertEquals(new Adam(2e-2), l0.getUpdater());
|
||||||
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
|
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
|
|
||||||
l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
|
l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new Adam(2e-2), l1.getIUpdater());
|
assertEquals(new Adam(2e-2), l1.getUpdater());
|
||||||
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
|
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
|
||||||
assertEquals(new WeightInitRelu(), l1.getWeightInit());
|
assertEquals(new WeightInitRelu(), l1.getWeightInit());
|
||||||
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
|
assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
|
||||||
|
|
|
@ -100,7 +100,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
msdx.put(key, msdxTmp);
|
msdx.put(key, msdxTmp);
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4);
|
assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getUpdater()).getRho(), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(4, count);
|
assertEquals(4, count);
|
||||||
|
@ -166,7 +166,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -186,7 +186,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
|
assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
|
||||||
assertEquals(2, count);
|
assertEquals(2, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -246,8 +246,8 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
|
assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getUpdater()).getBeta1(), 1e-4);
|
||||||
assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
|
assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getUpdater()).getBeta2(), 1e-4);
|
||||||
assertEquals(2, count);
|
assertEquals(2, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,7 +274,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
|
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -363,7 +363,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -399,8 +399,8 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
|
assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getUpdater()).getBeta1(), 1e-4);
|
||||||
assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
|
assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getUpdater()).getBeta2(), 1e-4);
|
||||||
assertEquals(2, count);
|
assertEquals(2, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -444,7 +444,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4);
|
assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getUpdater()).getMomentum(), 1e-4);
|
||||||
assertEquals(2, count);
|
assertEquals(2, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -466,7 +466,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
Updater updater = UpdaterCreator.getUpdater(layer);
|
Updater updater = UpdaterCreator.getUpdater(layer);
|
||||||
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) layer.getTypedLayerConfiguration().getUpdater().stateSize(numParams);
|
||||||
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
INDArray updaterState = Nd4j.create(1, updaterStateSize);
|
||||||
updater.setStateViewArray(layer, updaterState, true);
|
updater.setStateViewArray(layer, updaterState, true);
|
||||||
|
|
||||||
|
@ -496,7 +496,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
||||||
lastG.put(key, lastGTmp);
|
lastG.put(key, lastGTmp);
|
||||||
}
|
}
|
||||||
assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4);
|
assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getUpdater()).getRmsDecay(), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -528,7 +528,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
gradExpected = val.mul(lr);
|
gradExpected = val.mul(lr);
|
||||||
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
||||||
}
|
}
|
||||||
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
|
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -770,7 +770,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
gradExpected = val.mul(lr);
|
gradExpected = val.mul(lr);
|
||||||
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
|
||||||
}
|
}
|
||||||
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
|
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
|
||||||
|
|
||||||
|
|
||||||
//Test with pretrain == false
|
//Test with pretrain == false
|
||||||
|
@ -798,7 +798,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
|
||||||
layer.setBackpropGradientsViewArray(gradients);
|
layer.setBackpropGradientsViewArray(gradients);
|
||||||
updater = UpdaterCreator.getUpdater(layer);
|
updater = UpdaterCreator.getUpdater(layer);
|
||||||
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
|
assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getUpdater()).getLearningRate(), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//First: Check updater config
|
//First: Check updater config
|
||||||
assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
|
assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getUpdater() instanceof CustomIUpdater);
|
||||||
assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
|
assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getUpdater() instanceof CustomIUpdater);
|
||||||
assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
|
assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getUpdater() instanceof Sgd);
|
||||||
assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);
|
assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getUpdater() instanceof Sgd);
|
||||||
|
|
||||||
CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater();
|
CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getUpdater();
|
||||||
CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater();
|
CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getUpdater();
|
||||||
assertEquals(lr, u0_0.getLearningRate(), 1e-6);
|
assertEquals(lr, u0_0.getLearningRate(), 1e-6);
|
||||||
assertEquals(lr, u0_1.getLearningRate(), 1e-6);
|
assertEquals(lr, u0_1.getLearningRate(), 1e-6);
|
||||||
|
|
||||||
Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater();
|
Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getUpdater();
|
||||||
Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater();
|
Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getUpdater();
|
||||||
assertEquals(lr, u1_0.getLearningRate(), 1e-6);
|
assertEquals(lr, u1_0.getLearningRate(), 1e-6);
|
||||||
assertEquals(lr, u1_1.getLearningRate(), 1e-6);
|
assertEquals(lr, u1_1.getLearningRate(), 1e-6);
|
||||||
|
|
||||||
|
|
|
@ -73,8 +73,8 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
|
assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
|
||||||
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
||||||
assertEquals("softmax", l1.getActivationFn().toString());
|
assertEquals("softmax", l1.getActivationFn().toString());
|
||||||
|
@ -82,9 +82,9 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
|
assertEquals(new Nesterovs(0.15, 0.9), l1.getUpdater());
|
||||||
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
|
assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
|
||||||
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
int numParams = (int)net.numParams();
|
int numParams = (int)net.numParams();
|
||||||
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
||||||
|
@ -107,8 +107,8 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l0.getDropOut());
|
assertEquals(new Dropout(0.6), l0.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
|
||||||
|
@ -119,8 +119,8 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l1.getDropOut());
|
assertEquals(new Dropout(0.6), l1.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
|
||||||
|
@ -146,8 +146,8 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(3, l0.getNOut());
|
assertEquals(3, l0.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
|
@ -166,8 +166,8 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertEquals(26 * 26 * 3, l2.getNIn());
|
assertEquals(26 * 26 * 3, l2.getNIn());
|
||||||
assertEquals(5, l2.getNOut());
|
assertEquals(5, l2.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
int numParams = (int)net.numParams();
|
int numParams = (int)net.numParams();
|
||||||
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
||||||
|
|
|
@ -75,8 +75,8 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
|
assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
|
||||||
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
||||||
assertEquals("softmax", l1.getActivationFn().toString());
|
assertEquals("softmax", l1.getActivationFn().toString());
|
||||||
|
@ -84,9 +84,9 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater());
|
assertEquals(new Nesterovs(0.15, 0.9), l1.getUpdater());
|
||||||
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
|
assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
|
||||||
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
int numParams = (int)net.numParams();
|
int numParams = (int)net.numParams();
|
||||||
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
|
||||||
|
@ -109,8 +109,8 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l0.getDropOut());
|
assertEquals(new Dropout(0.6), l0.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l0));
|
||||||
|
@ -123,8 +123,8 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l1.getDropOut());
|
assertEquals(new Dropout(0.6), l1.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
|
||||||
|
@ -152,8 +152,8 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(3, l0.getNOut());
|
assertEquals(3, l0.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
|
@ -172,8 +172,8 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertEquals(26 * 26 * 3, l2.getNIn());
|
assertEquals(26 * 26 * 3, l2.getNIn());
|
||||||
assertEquals(5, l2.getNOut());
|
assertEquals(5, l2.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
|
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
|
||||||
|
|
||||||
|
|
|
@ -76,8 +76,8 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater());
|
assertEquals(new Nesterovs(0.15, 0.9), l0.getUpdater());
|
||||||
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
OutputLayer l1 = (OutputLayer) conf.getConf(1).getLayer();
|
||||||
assertEquals("softmax", l1.getActivationFn().toString());
|
assertEquals("softmax", l1.getActivationFn().toString());
|
||||||
|
@ -85,9 +85,9 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
|
assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
|
||||||
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
|
assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
|
||||||
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
long numParams = (int)net.numParams();
|
long numParams = (int)net.numParams();
|
||||||
assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
|
assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
|
||||||
|
@ -110,8 +110,8 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l0.getDropOut());
|
assertEquals(new Dropout(0.6), l0.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
|
||||||
|
@ -124,8 +124,8 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l1.getDropOut());
|
assertEquals(new Dropout(0.6), l1.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l1));
|
||||||
|
@ -153,8 +153,8 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(3, l0.getNOut());
|
assertEquals(3, l0.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
|
@ -173,8 +173,8 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertEquals(26 * 26 * 3, l2.getNIn());
|
assertEquals(26 * 26 * 3, l2.getNIn());
|
||||||
assertEquals(5, l2.getNOut());
|
assertEquals(5, l2.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater());
|
assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getUpdater());
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
|
|
||||||
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
|
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
|
||||||
|
|
||||||
|
|
|
@ -75,10 +75,10 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertTrue(l0.getIUpdater() instanceof Nesterovs);
|
assertTrue(l0.getUpdater() instanceof Nesterovs);
|
||||||
Nesterovs n = (Nesterovs) l0.getIUpdater();
|
Nesterovs n = (Nesterovs) l0.getUpdater();
|
||||||
assertEquals(0.9, n.getMomentum(), 1e-6);
|
assertEquals(0.9, n.getMomentum(), 1e-6);
|
||||||
assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(0.15, n.getLearningRate(), 1e-6);
|
assertEquals(0.15, n.getLearningRate(), 1e-6);
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,9 +88,9 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertTrue(l1.getIUpdater() instanceof Nesterovs);
|
assertTrue(l1.getUpdater() instanceof Nesterovs);
|
||||||
assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6);
|
assertEquals(0.9, ((Nesterovs)l1.getUpdater()).getMomentum(), 1e-6);
|
||||||
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((Nesterovs)l1.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(0.15, n.getLearningRate(), 1e-6);
|
assertEquals(0.15, n.getLearningRate(), 1e-6);
|
||||||
|
|
||||||
int numParams = (int)net.numParams();
|
int numParams = (int)net.numParams();
|
||||||
|
@ -114,11 +114,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit());
|
||||||
assertTrue(l0.getIUpdater() instanceof RmsProp);
|
assertTrue(l0.getUpdater() instanceof RmsProp);
|
||||||
RmsProp r = (RmsProp) l0.getIUpdater();
|
RmsProp r = (RmsProp) l0.getUpdater();
|
||||||
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
||||||
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l0.getDropOut());
|
assertEquals(new Dropout(0.6), l0.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.2,false), TestUtils.getWeightDecayReg(l0));
|
||||||
|
@ -131,11 +131,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(4, l1.getNIn());
|
assertEquals(4, l1.getNIn());
|
||||||
assertEquals(5, l1.getNOut());
|
assertEquals(5, l1.getNOut());
|
||||||
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
|
assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit());
|
||||||
assertTrue(l1.getIUpdater() instanceof RmsProp);
|
assertTrue(l1.getUpdater() instanceof RmsProp);
|
||||||
r = (RmsProp) l1.getIUpdater();
|
r = (RmsProp) l1.getUpdater();
|
||||||
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
||||||
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertEquals(new Dropout(0.6), l1.getDropOut());
|
assertEquals(new Dropout(0.6), l1.getDropOut());
|
||||||
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
assertEquals(0.1, TestUtils.getL1(l1), 1e-6);
|
||||||
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
|
||||||
|
@ -163,11 +163,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(3, l0.getNIn());
|
assertEquals(3, l0.getNIn());
|
||||||
assertEquals(3, l0.getNOut());
|
assertEquals(3, l0.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
assertEquals(new WeightInitRelu(), l0.getWeightInit());
|
||||||
assertTrue(l0.getIUpdater() instanceof RmsProp);
|
assertTrue(l0.getUpdater() instanceof RmsProp);
|
||||||
RmsProp r = (RmsProp) l0.getIUpdater();
|
RmsProp r = (RmsProp) l0.getUpdater();
|
||||||
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
||||||
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
||||||
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
|
assertEquals(0.15, ((RmsProp)l0.getUpdater()).getLearningRate(), 1e-6);
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
|
@ -186,8 +186,8 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertEquals(26 * 26 * 3, l2.getNIn());
|
assertEquals(26 * 26 * 3, l2.getNIn());
|
||||||
assertEquals(5, l2.getNOut());
|
assertEquals(5, l2.getNOut());
|
||||||
assertEquals(new WeightInitRelu(), l2.getWeightInit());
|
assertEquals(new WeightInitRelu(), l2.getWeightInit());
|
||||||
assertTrue(l2.getIUpdater() instanceof RmsProp);
|
assertTrue(l2.getUpdater() instanceof RmsProp);
|
||||||
r = (RmsProp) l2.getIUpdater();
|
r = (RmsProp) l2.getUpdater();
|
||||||
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
assertEquals(0.96, r.getRmsDecay(), 1e-6);
|
||||||
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
assertEquals(0.15, r.getLearningRate(), 1e-6);
|
||||||
|
|
||||||
|
|
|
@ -91,21 +91,21 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
assertEquals(200, l0.getNOut());
|
assertEquals(200, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new RmsProp(0.1), l0.getIUpdater());
|
assertEquals(new RmsProp(0.1), l0.getUpdater());
|
||||||
|
|
||||||
GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
|
GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(200, l1.getNOut());
|
assertEquals(200, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1));
|
||||||
assertEquals(new RmsProp(0.1), l1.getIUpdater());
|
assertEquals(new RmsProp(0.1), l1.getUpdater());
|
||||||
|
|
||||||
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
||||||
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
||||||
assertEquals(77, l2.getNOut());
|
assertEquals(77, l2.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new RmsProp(0.1), l0.getIUpdater());
|
assertEquals(new RmsProp(0.1), l0.getUpdater());
|
||||||
|
|
||||||
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
||||||
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
||||||
|
@ -141,7 +141,7 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new Adam(0.05), l0.getIUpdater());
|
assertEquals(new Adam(0.05), l0.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources.asFile("regression_testing/100a/VaeMNISTAnomaly_Output_100a.bin");
|
File f2 = Resources.asFile("regression_testing/100a/VaeMNISTAnomaly_Output_100a.bin");
|
||||||
|
|
|
@ -75,12 +75,12 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
||||||
assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new RmsProp(0.95), l0.getIUpdater());
|
assertEquals(new RmsProp(0.95), l0.getUpdater());
|
||||||
|
|
||||||
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
||||||
assertEquals(new RmsProp(0.95), l1.getIUpdater());
|
assertEquals(new RmsProp(0.95), l1.getUpdater());
|
||||||
|
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
|
@ -126,21 +126,21 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
assertEquals(200, l0.getNOut());
|
assertEquals(200, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
|
|
||||||
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(200, l1.getNOut());
|
assertEquals(200, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
|
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1));
|
||||||
assertEquals(new Adam(0.005), l1.getIUpdater());
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
||||||
|
|
||||||
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
||||||
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
||||||
assertEquals(77, l2.getNOut());
|
assertEquals(77, l2.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
||||||
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
|
|
||||||
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
||||||
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
||||||
|
@ -176,7 +176,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
|
assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0));
|
||||||
assertEquals(new Adam(1e-3), l0.getIUpdater());
|
assertEquals(new Adam(1e-3), l0.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources.asFile("regression_testing/100b3/VaeMNISTAnomaly_Output_100b3.bin");
|
File f2 = Resources.asFile("regression_testing/100b3/VaeMNISTAnomaly_Output_100b3.bin");
|
||||||
|
|
|
@ -94,12 +94,12 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
||||||
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new RmsProp(0.95), l0.getIUpdater());
|
assertEquals(new RmsProp(0.95), l0.getUpdater());
|
||||||
|
|
||||||
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
||||||
assertEquals(new RmsProp(0.95), l1.getIUpdater());
|
assertEquals(new RmsProp(0.95), l1.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources
|
File f2 = Resources
|
||||||
|
@ -144,21 +144,21 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertEquals(200, l0.getNOut());
|
assertEquals(200, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
|
|
||||||
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(200, l1.getNOut());
|
assertEquals(200, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
||||||
assertEquals(new Adam(0.005), l1.getIUpdater());
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
||||||
|
|
||||||
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
||||||
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
||||||
assertEquals(77, l2.getNOut());
|
assertEquals(77, l2.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
||||||
assertEquals(new Adam(0.005), l2.getIUpdater());
|
assertEquals(new Adam(0.005), l2.getUpdater());
|
||||||
|
|
||||||
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
||||||
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
||||||
|
@ -194,7 +194,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(1e-3), l0.getIUpdater());
|
assertEquals(new Adam(1e-3), l0.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin");
|
File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin");
|
||||||
|
@ -262,7 +262,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[]{2, 1}, l0.getStride());
|
assertArrayEquals(new int[]{2, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
|
||||||
|
@ -273,7 +273,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertEquals(8, l1.getNOut());
|
assertEquals(8, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
||||||
assertEquals(new Adam(0.005), l1.getIUpdater());
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l1.getStride());
|
assertArrayEquals(new int[]{1, 1}, l1.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
|
||||||
|
@ -299,7 +299,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertEquals(16, l5.getNOut());
|
assertEquals(16, l5.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l5.getWeightInit());
|
assertEquals(new WeightInitXavier(), l5.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
||||||
assertEquals(new Adam(0.005), l5.getIUpdater());
|
assertEquals(new Adam(0.005), l5.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l5.getStride());
|
assertArrayEquals(new int[]{1, 1}, l5.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
|
||||||
|
@ -320,7 +320,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
assertEquals(4, l8.getNOut());
|
assertEquals(4, l8.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
||||||
assertEquals(new Adam(0.005), l8.getIUpdater());
|
assertEquals(new Adam(0.005), l8.getUpdater());
|
||||||
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l8.getStride());
|
assertArrayEquals(new int[]{1, 1}, l8.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
|
||||||
|
@ -329,7 +329,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
||||||
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
||||||
assertEquals(new Adam(0.005), l9.getIUpdater());
|
assertEquals(new Adam(0.005), l9.getUpdater());
|
||||||
assertEquals(new LossMAE(), l9.getLossFunction());
|
assertEquals(new LossMAE(), l9.getLossFunction());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
|
|
|
@ -76,12 +76,12 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
||||||
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new RmsProp(0.95), l0.getIUpdater());
|
assertEquals(new RmsProp(0.95), l0.getUpdater());
|
||||||
|
|
||||||
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
||||||
assertEquals(new RmsProp(0.95), l1.getIUpdater());
|
assertEquals(new RmsProp(0.95), l1.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources
|
File f2 = Resources
|
||||||
|
@ -126,21 +126,21 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(200, l0.getNOut());
|
assertEquals(200, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
|
|
||||||
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
||||||
assertEquals(200, l1.getNOut());
|
assertEquals(200, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
||||||
assertEquals(new Adam(0.005), l1.getIUpdater());
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
||||||
|
|
||||||
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
||||||
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
||||||
assertEquals(77, l2.getNOut());
|
assertEquals(77, l2.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
||||||
assertEquals(new Adam(0.005), l2.getIUpdater());
|
assertEquals(new Adam(0.005), l2.getUpdater());
|
||||||
|
|
||||||
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
||||||
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
||||||
|
@ -176,7 +176,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(1e-3), l0.getIUpdater());
|
assertEquals(new Adam(1e-3), l0.getUpdater());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
|
File f2 = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_Output_100b6.bin");
|
||||||
|
@ -242,7 +242,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(4, l0.getNOut());
|
assertEquals(4, l0.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
||||||
assertEquals(new Adam(0.005), l0.getIUpdater());
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[]{2, 1}, l0.getStride());
|
assertArrayEquals(new int[]{2, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
|
||||||
|
@ -253,7 +253,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(8, l1.getNOut());
|
assertEquals(8, l1.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
||||||
assertEquals(new Adam(0.005), l1.getIUpdater());
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l1.getStride());
|
assertArrayEquals(new int[]{1, 1}, l1.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
|
||||||
|
@ -279,7 +279,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(16, l5.getNOut());
|
assertEquals(16, l5.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l5.getWeightInit());
|
assertEquals(new WeightInitXavier(), l5.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
||||||
assertEquals(new Adam(0.005), l5.getIUpdater());
|
assertEquals(new Adam(0.005), l5.getUpdater());
|
||||||
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
|
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l5.getStride());
|
assertArrayEquals(new int[]{1, 1}, l5.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
|
||||||
|
@ -300,7 +300,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
assertEquals(4, l8.getNOut());
|
assertEquals(4, l8.getNOut());
|
||||||
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
||||||
assertEquals(new Adam(0.005), l8.getIUpdater());
|
assertEquals(new Adam(0.005), l8.getUpdater());
|
||||||
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
||||||
assertArrayEquals(new int[]{1, 1}, l8.getStride());
|
assertArrayEquals(new int[]{1, 1}, l8.getStride());
|
||||||
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
|
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
|
||||||
|
@ -309,7 +309,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
||||||
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
||||||
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
||||||
assertEquals(new Adam(0.005), l9.getIUpdater());
|
assertEquals(new Adam(0.005), l9.getUpdater());
|
||||||
assertEquals(new LossMAE(), l9.getLossFunction());
|
assertEquals(new LossMAE(), l9.getLossFunction());
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
|
|
|
@ -113,7 +113,7 @@ public class CustomLayer extends FeedForwardLayer {
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
int updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int trainSizeFixed = 0;
|
int trainSizeFixed = 0;
|
||||||
int trainSizeVariable = 0;
|
int trainSizeVariable = 0;
|
||||||
|
|
|
@ -235,7 +235,7 @@ public class GradientCheckUtil {
|
||||||
for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) {
|
for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) {
|
||||||
if (n instanceof BaseLayerConfiguration) {
|
if (n instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) n;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) n;
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u instanceof Sgd) {
|
if (u instanceof Sgd) {
|
||||||
// Must have LR of 1.0
|
// Must have LR of 1.0
|
||||||
double lr = ((Sgd) u).getLearningRate();
|
double lr = ((Sgd) u).getLearningRate();
|
||||||
|
@ -540,7 +540,7 @@ public class GradientCheckUtil {
|
||||||
|
|
||||||
if (lv.getLayerConfiguration() instanceof BaseLayerConfiguration) {
|
if (lv.getLayerConfiguration() instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) lv.getLayerConfiguration();
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) lv.getLayerConfiguration();
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u instanceof Sgd) {
|
if (u instanceof Sgd) {
|
||||||
// Must have LR of 1.0
|
// Must have LR of 1.0
|
||||||
double lr = ((Sgd) u).getLearningRate();
|
double lr = ((Sgd) u).getLearningRate();
|
||||||
|
|
|
@ -322,7 +322,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
@Getter @Setter @lombok.Builder.Default @JsonIgnore private IUpdater iUpdater = new Sgd();
|
@Getter @Setter @lombok.Builder.Default @JsonIgnore private IUpdater iUpdater = new Sgd();
|
||||||
/**
|
/**
|
||||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
||||||
* set by {@link #setIUpdater(IUpdater)}<br>
|
* set by {@link #setUpdater(IUpdater)}<br>
|
||||||
* Note: values set by this method will be applied to all applicable layers in the network, unless
|
* Note: values set by this method will be applied to all applicable layers in the network, unless
|
||||||
* a different value is explicitly set on a given layer. In other words: values set via this
|
* a different value is explicitly set on a given layer. In other words: values set via this
|
||||||
* method are used as the default value, and can be overridden on a per-layer basis.
|
* method are used as the default value, and can be overridden on a per-layer basis.
|
||||||
|
|
|
@ -65,7 +65,7 @@ public class ActivationLayer extends NoParamLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IUpdater getIUpdater() {
|
public IUpdater getUpdater() {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,7 +88,7 @@ public class AutoEncoder extends BasePretrainNetwork {
|
||||||
|
|
||||||
val actElementsPerEx = outputType.arrayElementsPerExample() + inputType.arrayElementsPerExample();
|
val actElementsPerEx = outputType.arrayElementsPerExample() + inputType.arrayElementsPerExample();
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int trainSizePerEx = 0;
|
int trainSizePerEx = 0;
|
||||||
if (getDropOut() != null) {
|
if (getDropOut() != null) {
|
||||||
|
|
|
@ -95,7 +95,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
||||||
* org.nd4j.linalg.learning.config.Nesterovs}
|
* org.nd4j.linalg.learning.config.Nesterovs}
|
||||||
*/
|
*/
|
||||||
@Getter @Setter
|
@Getter
|
||||||
protected IUpdater updater;
|
protected IUpdater updater;
|
||||||
/**
|
/**
|
||||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
||||||
|
@ -134,7 +134,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
*/
|
*/
|
||||||
public void resetLayerDefaultConfig() {
|
public void resetLayerDefaultConfig() {
|
||||||
// clear the learning related params for all layers in the origConf and set to defaults
|
// clear the learning related params for all layers in the origConf and set to defaults
|
||||||
this.setIUpdater(null);
|
this.setUpdater( (IUpdater) null);
|
||||||
this.setWeightInit(null);
|
this.setWeightInit(null);
|
||||||
this.setBiasInit(Double.NaN);
|
this.setBiasInit(Double.NaN);
|
||||||
this.setGainInit(Double.NaN);
|
this.setGainInit(Double.NaN);
|
||||||
|
@ -142,10 +142,16 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
this.regularizationBias = null;
|
this.regularizationBias = null;
|
||||||
this.setGradientNormalization(GradientNormalization.None);
|
this.setGradientNormalization(GradientNormalization.None);
|
||||||
this.setGradientNormalizationThreshold(1.0);
|
this.setGradientNormalizationThreshold(1.0);
|
||||||
this.updater = null;
|
|
||||||
this.biasUpdater = null;
|
this.biasUpdater = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setUpdater(Updater updater) {
|
||||||
|
setUpdater(updater.getIUpdaterWithDefaultConfig());
|
||||||
|
}
|
||||||
|
public void setUpdater(IUpdater iUpdater) {
|
||||||
|
this.updater=iUpdater;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BaseLayerConfiguration clone() {
|
public BaseLayerConfiguration clone() {
|
||||||
BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
|
BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone();
|
||||||
|
@ -203,6 +209,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
if (this.updater == null) this.updater = conf.getUpdater();
|
if (this.updater == null) this.updater = conf.getUpdater();
|
||||||
if (this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
|
if (this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias();
|
||||||
if (this.regularization == null) this.regularization = conf.getRegularization();
|
if (this.regularization == null) this.regularization = conf.getRegularization();
|
||||||
|
if( this.weightInit == null) this.weightInit = conf.getWeightInit();
|
||||||
if (this.gradientNormalization == null)
|
if (this.gradientNormalization == null)
|
||||||
this.gradientNormalization = conf.getGradientNormalization();
|
this.gradientNormalization = conf.getGradientNormalization();
|
||||||
// if(this.weightInit == null) this.weightInit = conf.getWeightInit();
|
// if(this.weightInit == null) this.weightInit = conf.getWeightInit();
|
||||||
|
|
|
@ -56,7 +56,7 @@ public abstract class BaseOutputLayer extends FeedForwardLayer {
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int trainSizeFixed = 0;
|
int trainSizeFixed = 0;
|
||||||
int trainSizeVariable = 0;
|
int trainSizeVariable = 0;
|
||||||
|
|
|
@ -235,7 +235,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
@Override
|
@Override
|
||||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
val paramSize = initializer().numParams(this);
|
val paramSize = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(paramSize);
|
val updaterStateSize = (int) getUpdater().stateSize(paramSize);
|
||||||
|
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
InputType.InputTypeConvolutional outputType =
|
InputType.InputTypeConvolutional outputType =
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class DenseLayer extends FeedForwardLayer {
|
||||||
LayerValidation.assertNInNOutSet(
|
LayerValidation.assertNInNOutSet(
|
||||||
"DenseLayerConfiguration", getName(), layerIndex, getNIn(), getNOut());
|
"DenseLayerConfiguration", getName(), layerIndex, getNIn(), getNOut());
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
lconf.runInheritance();
|
runInheritance();
|
||||||
|
|
||||||
org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
|
org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
|
||||||
new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
|
new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType);
|
||||||
|
@ -84,7 +84,7 @@ public class DenseLayer extends FeedForwardLayer {
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int trainSizeFixed = 0;
|
int trainSizeFixed = 0;
|
||||||
int trainSizeVariable = 0;
|
int trainSizeVariable = 0;
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
|
||||||
|
|
||||||
val actElementsPerEx = outputType.arrayElementsPerExample();
|
val actElementsPerEx = outputType.arrayElementsPerExample();
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
// Embedding layer does not use caching.
|
// Embedding layer does not use caching.
|
||||||
// Inference: no working memory - just activations (pullRows)
|
// Inference: no working memory - just activations (pullRows)
|
||||||
|
|
|
@ -162,7 +162,7 @@ extends FeedForwardLayerBuilder<C, B> {
|
||||||
|
|
||||||
val actElementsPerEx = outputType.arrayElementsPerExample();
|
val actElementsPerEx = outputType.arrayElementsPerExample();
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
return new LayerMemoryReport.Builder(name, EmbeddingSequenceLayer.class, inputType, outputType)
|
return new LayerMemoryReport.Builder(name, EmbeddingSequenceLayer.class, inputType, outputType)
|
||||||
.standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx)
|
.standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx)
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.Updater;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -281,16 +282,6 @@ public abstract class LayerConfiguration
|
||||||
"Not supported: all layers with parameters should override this method");
|
"Not supported: all layers with parameters should override this method");
|
||||||
}
|
}
|
||||||
|
|
||||||
public IUpdater getIUpdater() {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"Not supported: all layers with parameters should override this method");
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setIUpdater(IUpdater iUpdater) {
|
|
||||||
log.warn(
|
|
||||||
"Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is a report of the estimated memory consumption for the given layer
|
* This is a report of the estimated memory consumption for the given layer
|
||||||
*
|
*
|
||||||
|
@ -316,6 +307,7 @@ public abstract class LayerConfiguration
|
||||||
if (this.activation == null) this.activation = conf.getActivation();
|
if (this.activation == null) this.activation = conf.getActivation();
|
||||||
if (this.dropOut == null) this.dropOut = conf.getIdropOut();
|
if (this.dropOut == null) this.dropOut = conf.getIdropOut();
|
||||||
if (this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
|
if (this.weightNoise == null) this.weightNoise = conf.getWeightNoise();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -326,6 +318,24 @@ public abstract class LayerConfiguration
|
||||||
runInheritance(getNetConfiguration());
|
runInheritance(getNetConfiguration());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This will always return no-Op updater.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public IUpdater getUpdater() {
|
||||||
|
log.warn("Calling getUpdater() in {} will always return no-Op Updater.", LayerConfiguration.class.getSimpleName());
|
||||||
|
return Updater.NONE.getIUpdaterWithDefaultConfig();
|
||||||
|
}
|
||||||
|
@Deprecated
|
||||||
|
public void setUpdater(Updater updater) {
|
||||||
|
setUpdater(updater.getIUpdaterWithDefaultConfig());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setUpdater(IUpdater iUpdater) {
|
||||||
|
throw new RuntimeException("When " + this.getName() + " wants to you an Updater, it needs to override the "
|
||||||
|
+ "Getter/ Setter for the Updater and not rely on LayerConfiguration class.");
|
||||||
|
}
|
||||||
|
|
||||||
public abstract static class LayerConfigurationBuilder<
|
public abstract static class LayerConfigurationBuilder<
|
||||||
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
||||||
|
|
||||||
|
|
|
@ -249,17 +249,6 @@ public class LocallyConnected1D extends SameDiffLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
NeuralNetConfiguration global_conf = globalConfig.build();
|
|
||||||
if (activation == null) {
|
|
||||||
activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation());
|
|
||||||
}
|
|
||||||
if (convolutionMode == null) {
|
|
||||||
convolutionMode = global_conf.getConvolutionMode();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final class LocallyConnected1DBuilderImpl
|
private static final class LocallyConnected1DBuilderImpl
|
||||||
extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> {
|
extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> {
|
||||||
|
|
|
@ -305,17 +305,6 @@ public class LocallyConnected2D extends SameDiffLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
NeuralNetConfiguration gconf = globalConfig.build();
|
|
||||||
if (activation == null) {
|
|
||||||
activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation());
|
|
||||||
}
|
|
||||||
if (convolutionMode == null) {
|
|
||||||
convolutionMode = gconf.getConvolutionMode();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final class LocallyConnected2DBuilderImpl
|
private static final class LocallyConnected2DBuilderImpl
|
||||||
extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> {
|
extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> {
|
||||||
|
|
|
@ -56,10 +56,11 @@ public abstract class NoParamLayer extends LayerConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Will always return no-Op updater.
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public IUpdater getIUpdater() {
|
public IUpdater getUpdater() {
|
||||||
return Updater.NONE.getIUpdaterWithDefaultConfig();
|
return Updater.NONE.getIUpdaterWithDefaultConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,9 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
@SuperBuilder(builderMethodName = "innerBuilder")
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
public class OutputLayer extends BaseOutputLayer {
|
public class OutputLayer extends BaseOutputLayer {
|
||||||
|
|
||||||
{ // Set default activation function to softmax (to match default loss function MCXENT)
|
|
||||||
setActivation(Activation.SOFTMAX.getActivationFunction());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static OutputLayerBuilder<?, ?> builder() {
|
public static OutputLayerBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
|
|
|
@ -120,7 +120,7 @@ public class PReLULayer extends BaseLayerConfiguration {
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
return new LayerMemoryReport.Builder(name, PReLULayer.class, inputType, outputType)
|
return new LayerMemoryReport.Builder(name, PReLULayer.class, inputType, outputType)
|
||||||
.standardMemory(numParams, updaterStateSize)
|
.standardMemory(numParams, updaterStateSize)
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
|
import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -44,7 +46,9 @@ import java.util.Map;
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
public class RecurrentAttentionLayer extends SameDiffLayer {
|
public class RecurrentAttentionLayer extends SameDiffLayer implements ITraininableLayerConfiguration {
|
||||||
|
|
||||||
|
private DataType dataType;
|
||||||
|
|
||||||
private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> {
|
private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> {
|
||||||
public RecurrentAttentionLayer build() {
|
public RecurrentAttentionLayer build() {
|
||||||
|
@ -190,13 +194,6 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
if (activation == null) {
|
|
||||||
activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void validateInput(INDArray input) {
|
public void validateInput(INDArray input) {
|
||||||
final long inputLength = input.size(2);
|
final long inputLength = input.size(2);
|
||||||
|
|
|
@ -91,7 +91,7 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
val updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
val updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int trainSizeFixed = 0;
|
int trainSizeFixed = 0;
|
||||||
int trainSizeVariable = 0;
|
int trainSizeVariable = 0;
|
||||||
|
|
|
@ -82,6 +82,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration {
|
||||||
org.deeplearning4j.nn.api.Layer newUnderlyingLayer = underlying.instantiate(conf, trainingListeners,
|
org.deeplearning4j.nn.api.Layer newUnderlyingLayer = underlying.instantiate(conf, trainingListeners,
|
||||||
layerIndex, layerParamsView, initializeParams, networkDataType);
|
layerIndex, layerParamsView, initializeParams, networkDataType);
|
||||||
|
|
||||||
|
runInheritance();
|
||||||
newUnderlyingLayer.setLayerConfiguration(underlying); //Fix a problem, where the embedded layer gets the conf of the frozen layer, rather than its own
|
newUnderlyingLayer.setLayerConfiguration(underlying); //Fix a problem, where the embedded layer gets the conf of the frozen layer, rather than its own
|
||||||
NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration();
|
NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration();
|
||||||
|
|
||||||
|
@ -130,4 +131,6 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration {
|
||||||
this.constraints = constraints;
|
this.constraints = constraints;
|
||||||
this.underlying.setConstraints(constraints);
|
this.underlying.setConstraints(constraints);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitUtil;
|
import org.deeplearning4j.nn.weights.WeightInitUtil;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.util.NetworkUtils;
|
import org.deeplearning4j.util.NetworkUtils;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -93,6 +94,21 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
private SDLayerParams layerParams;
|
private SDLayerParams layerParams;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void runInheritance(@NotNull NeuralNetConfiguration conf) {
|
||||||
|
super.runInheritance(conf);
|
||||||
|
if (this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater();
|
||||||
|
if (this.updater == null) this.updater = conf.getUpdater();
|
||||||
|
if (this.regularizationBias == null || regularizationBias.isEmpty()) this.regularizationBias = conf.getRegularizationBias();
|
||||||
|
if (this.regularization == null || regularization.isEmpty()) this.regularization = conf.getRegularization();
|
||||||
|
// if( this.weightInit == null) this.weightInit = conf.getWeightInit();
|
||||||
|
if (this.gradientNormalization == null)
|
||||||
|
this.gradientNormalization = conf.getGradientNormalization();
|
||||||
|
// if(this.weightInit == null) this.weightInit = conf.getWeightInit();
|
||||||
|
if (Double.isNaN(gradientNormalizationThreshold)) {
|
||||||
|
this.gradientNormalizationThreshold = conf.getGradientNormalizationThreshold();
|
||||||
|
}
|
||||||
|
}
|
||||||
@Override
|
@Override
|
||||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||||
if (layerParams.isWeightParam(paramName)) {
|
if (layerParams.isWeightParam(paramName)) {
|
||||||
|
@ -122,10 +138,6 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void applyGlobalConfigToLayer(
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
|
||||||
// Default implementation: no op
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
|
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
|
||||||
|
@ -195,29 +207,6 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
|
|
||||||
NeuralNetConfiguration bConf = b.build();
|
|
||||||
if (regularization == null || regularization.isEmpty()) {
|
|
||||||
regularization = bConf.getRegularization();
|
|
||||||
}
|
|
||||||
if (regularizationBias == null || regularizationBias.isEmpty()) {
|
|
||||||
regularizationBias = bConf.getRegularizationBias();
|
|
||||||
}
|
|
||||||
if (updater == null) {
|
|
||||||
updater = bConf.getUpdater();
|
|
||||||
}
|
|
||||||
if (biasUpdater == null) {
|
|
||||||
biasUpdater = bConf.getBiasUpdater();
|
|
||||||
}
|
|
||||||
if (gradientNormalization == null) {
|
|
||||||
gradientNormalization = bConf.getGradientNormalization();
|
|
||||||
}
|
|
||||||
if (Double.isNaN(gradientNormalizationThreshold)) {
|
|
||||||
gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold();
|
|
||||||
}
|
|
||||||
|
|
||||||
applyGlobalConfigToLayer(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method generates an "all ones" mask array for use in the SameDiff model when none is
|
* This method generates an "all ones" mask array for use in the SameDiff model when none is
|
||||||
|
|
|
@ -137,7 +137,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
|
||||||
|
|
||||||
val actElementsPerEx = outputType.arrayElementsPerExample();
|
val actElementsPerEx = outputType.arrayElementsPerExample();
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
int updaterStateSize = (int) getIUpdater().stateSize(numParams);
|
int updaterStateSize = (int) getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
int inferenceWorkingMemSizePerEx = 0;
|
int inferenceWorkingMemSizePerEx = 0;
|
||||||
// Forward pass size through the encoder:
|
// Forward pass size through the encoder:
|
||||||
|
|
|
@ -123,17 +123,17 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
/**
|
/**
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Override
|
|
||||||
public IUpdater getIUpdater() {
|
public IUpdater getUpdater() {
|
||||||
return underlying.getIUpdater();
|
return underlying.getUpdater();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param iUpdater
|
* @param iUpdater
|
||||||
*/
|
*/
|
||||||
@Override
|
|
||||||
public void setIUpdater(IUpdater iUpdater) {
|
public void setUpdater(IUpdater iUpdater) {
|
||||||
underlying.setIUpdater(iUpdater);
|
underlying.setUpdater(iUpdater);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -68,7 +68,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
|
||||||
for(LayerConfiguration l : layers){
|
for(LayerConfiguration l : layers){
|
||||||
if(l instanceof BaseLayerConfiguration){
|
if(l instanceof BaseLayerConfiguration){
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration)l;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration)l;
|
||||||
if(bl.getIUpdater() == null && bl.initializer().numParams(bl) > 0){
|
if(bl.getUpdater() == null && bl.initializer().numParams(bl) > 0){
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
layer.setIUpdater(iu);
|
layer.setUpdater(iu);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,7 +116,7 @@ public class ComputationGraphConfigurationDeserializer
|
||||||
}
|
}
|
||||||
|
|
||||||
if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayerConfiguration
|
if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayerConfiguration
|
||||||
&& ((BaseLayerConfiguration)layers[layerIdx]).getIUpdater() == null){
|
&& ((BaseLayerConfiguration)layers[layerIdx]).getUpdater() == null){
|
||||||
handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
|
handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
|
||||||
ObjectNode on = (ObjectNode) confsNode.get(i);
|
ObjectNode on = (ObjectNode) confsNode.get(i);
|
||||||
ObjectNode confNode = null;
|
ObjectNode confNode = null;
|
||||||
if(layers[i] instanceof BaseLayerConfiguration
|
if(layers[i] instanceof BaseLayerConfiguration
|
||||||
&& ((BaseLayerConfiguration)layers[i]).getIUpdater() == null){
|
&& ((BaseLayerConfiguration)layers[i]).getUpdater() == null){
|
||||||
//layer -> (first/only child) -> updater
|
//layer -> (first/only child) -> updater
|
||||||
if(on.has("layer")){
|
if(on.has("layer")){
|
||||||
confNode = on;
|
confNode = on;
|
||||||
|
|
|
@ -77,6 +77,7 @@ public class WeightNoise implements IWeightNoise {
|
||||||
(applyToBias && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) {
|
(applyToBias && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) {
|
||||||
|
|
||||||
org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution);
|
org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution);
|
||||||
|
|
||||||
INDArray noise = dist.sample(param.ulike());
|
INDArray noise = dist.sample(param.ulike());
|
||||||
INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering());
|
INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering());
|
||||||
|
|
||||||
|
|
|
@ -577,7 +577,8 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Map<String, INDArray> getParamTable(boolean isBackprop) {
|
public Map<String, INDArray> getParamTable(boolean isBackprop) {
|
||||||
throw new RuntimeException("Not implemented");
|
// throw new RuntimeException("Not implemented");
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -340,6 +340,6 @@ public class GlobalPoolingLayer extends AbstractLayer<org.deeplearning4j.nn.conf
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void setParamTable(Map<String, INDArray> paramTable) {
|
public void setParamTable(Map<String, INDArray> paramTable) {
|
||||||
throw new RuntimeException("Not implemented.");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -768,7 +768,7 @@ public class LSTMHelpers {
|
||||||
InputType outputType = lstmLayer.getOutputType(-1, inputType);
|
InputType outputType = lstmLayer.getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = lstmLayer.initializer().numParams(lstmLayer);
|
val numParams = lstmLayer.initializer().numParams(lstmLayer);
|
||||||
int updaterSize = (int) lstmLayer.getIUpdater().stateSize(numParams);
|
int updaterSize = (int) lstmLayer.getUpdater().stateSize(numParams);
|
||||||
|
|
||||||
//Memory use during forward pass:
|
//Memory use during forward pass:
|
||||||
//ifogActivations: nTimeSteps * [minibatch,4*layerSize] (not cached during inference fwd pass)
|
//ifogActivations: nTimeSteps * [minibatch,4*layerSize] (not cached during inference fwd pass)
|
||||||
|
|
|
@ -565,7 +565,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
|
||||||
* inference/backprop are excluded from the returned list.
|
* inference/backprop are excluded from the returned list.
|
||||||
*
|
*
|
||||||
* @param backpropParamsOnly If true, return backprop params only. If false: return all params
|
* @param backpropParamsOnly If true, return backprop params only. If false: return all params
|
||||||
* @return Parameters for the network
|
* @return Parameters for the network, empty Map if no parameters present in the neural network
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) {
|
public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) {
|
||||||
|
@ -573,10 +573,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork
|
||||||
Map<String, INDArray> allParams = new LinkedHashMap<>();
|
Map<String, INDArray> allParams = new LinkedHashMap<>();
|
||||||
for (int i = 0; i < layers.length; i++) {
|
for (int i = 0; i < layers.length; i++) {
|
||||||
Map<String, INDArray> paramMap = layers[i].getParamTable(backpropParamsOnly);
|
Map<String, INDArray> paramMap = layers[i].getParamTable(backpropParamsOnly);
|
||||||
|
if(paramMap!=null){
|
||||||
for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
|
for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
|
||||||
String newKey = i + "_" + entry.getKey();
|
String newKey = i + "_" + entry.getKey();
|
||||||
allParams.put(newKey, entry.getValue());
|
allParams.put(newKey, entry.getValue());
|
||||||
}
|
}}
|
||||||
}
|
}
|
||||||
return allParams;
|
return allParams;
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,6 +94,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
|
||||||
if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer))
|
if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer))
|
||||||
throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName());
|
throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName());
|
||||||
|
|
||||||
|
INDArray reshapedParamsView = paramsView.reshape(paramsView.length());
|
||||||
Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
|
Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
|
||||||
|
|
||||||
val length = numParams(conf);
|
val length = numParams(conf);
|
||||||
|
@ -107,14 +108,14 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
|
||||||
val nOut = layerConf.getNOut();
|
val nOut = layerConf.getNOut();
|
||||||
|
|
||||||
val nWeightParams = nIn * nOut;
|
val nWeightParams = nIn * nOut;
|
||||||
INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams));
|
INDArray weightView = reshapedParamsView.get(NDArrayIndex.interval(0, nWeightParams));
|
||||||
|
|
||||||
params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams));
|
params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams));
|
||||||
layerConf.addVariable(WEIGHT_KEY);
|
layerConf.addVariable(WEIGHT_KEY);
|
||||||
|
|
||||||
long offset = nWeightParams;
|
long offset = nWeightParams;
|
||||||
if(hasBias(layerConf)){
|
if(hasBias(layerConf)){
|
||||||
INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true),
|
INDArray biasView = reshapedParamsView.get(
|
||||||
NDArrayIndex.interval(offset, offset + nOut));
|
NDArrayIndex.interval(offset, offset + nOut));
|
||||||
params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams));
|
params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams));
|
||||||
layerConf.addVariable(BIAS_KEY);
|
layerConf.addVariable(BIAS_KEY);
|
||||||
|
@ -122,7 +123,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(hasLayerNorm(layerConf)){
|
if(hasLayerNorm(layerConf)){
|
||||||
INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true),
|
INDArray gainView = reshapedParamsView.get(
|
||||||
NDArrayIndex.interval(offset, offset + nOut));
|
NDArrayIndex.interval(offset, offset + nOut));
|
||||||
params.put(GAIN_KEY, createGain(conf, gainView, initializeParams));
|
params.put(GAIN_KEY, createGain(conf, gainView, initializeParams));
|
||||||
conf.addVariable(GAIN_KEY);
|
conf.addVariable(GAIN_KEY);
|
||||||
|
@ -138,23 +139,24 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
|
||||||
val nIn = layerConf.getNIn();
|
val nIn = layerConf.getNIn();
|
||||||
val nOut = layerConf.getNOut();
|
val nOut = layerConf.getNOut();
|
||||||
val nWeightParams = nIn * nOut;
|
val nWeightParams = nIn * nOut;
|
||||||
|
INDArray gradientViewReshaped = gradientView.reshape(gradientView.length());
|
||||||
|
|
||||||
INDArray weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams))
|
INDArray weightGradientView = gradientViewReshaped.get(NDArrayIndex.interval(0, nWeightParams))
|
||||||
.reshape('f', nIn, nOut);
|
.reshape('f', nIn, nOut);
|
||||||
|
|
||||||
Map<String, INDArray> out = new LinkedHashMap<>();
|
Map<String, INDArray> out = new LinkedHashMap<>();
|
||||||
out.put(WEIGHT_KEY, weightGradientView);
|
out.put(WEIGHT_KEY, weightGradientView);
|
||||||
|
|
||||||
long offset = nWeightParams;
|
long offset = nWeightParams;
|
||||||
if(hasBias(layerConf)){
|
if(hasBias(layerConf)){
|
||||||
INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true),
|
INDArray biasView = gradientViewReshaped.get(
|
||||||
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
|
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
|
||||||
out.put(BIAS_KEY, biasView);
|
out.put(BIAS_KEY, biasView);
|
||||||
offset += nOut;
|
offset += nOut;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(hasLayerNorm(layerConf)){
|
if(hasLayerNorm(layerConf)){
|
||||||
INDArray gainView = gradientView.get(NDArrayIndex.interval(0,0,true),
|
INDArray gainView = gradientViewReshaped.get(
|
||||||
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
|
NDArrayIndex.interval(offset, offset + nOut)); //Already a row vector
|
||||||
out.put(GAIN_KEY, gainView);
|
out.put(GAIN_KEY, gainView);
|
||||||
}
|
}
|
||||||
|
@ -196,13 +198,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
|
||||||
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;
|
(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf;
|
||||||
|
|
||||||
if (initializeParameters) {
|
if (initializeParameters) {
|
||||||
if( layerConf.getWeightInit() == null) {
|
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
|
||||||
// set a default and set warning
|
|
||||||
layerConf.setWeightInit(new WeightInitXavier());
|
|
||||||
log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getName(),
|
|
||||||
conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName());
|
|
||||||
}
|
|
||||||
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(),
|
|
||||||
weightParamView, true);
|
weightParamView, true);
|
||||||
} else {
|
} else {
|
||||||
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false);
|
return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false);
|
||||||
|
|
|
@ -200,7 +200,7 @@ public class FineTuneConfiguration {
|
||||||
bl.setGradientNormalizationThreshold(gradientNormalizationThreshold);
|
bl.setGradientNormalizationThreshold(gradientNormalizationThreshold);
|
||||||
}
|
}
|
||||||
if (updater != null) {
|
if (updater != null) {
|
||||||
bl.setIUpdater(updater);
|
bl.setUpdater(updater);
|
||||||
}
|
}
|
||||||
if (biasUpdater != null) {
|
if (biasUpdater != null) {
|
||||||
bl.setBiasUpdater(biasUpdater);
|
bl.setBiasUpdater(biasUpdater);
|
||||||
|
|
|
@ -125,7 +125,7 @@ public class NetworkUtils {
|
||||||
LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration();
|
LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration();
|
||||||
if (l instanceof BaseLayerConfiguration) {
|
if (l instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u != null && u.hasLearningRate()) {
|
if (u != null && u.hasLearningRate()) {
|
||||||
if (newLrSchedule != null) {
|
if (newLrSchedule != null) {
|
||||||
u.setLrAndSchedule(Double.NaN, newLrSchedule);
|
u.setLrAndSchedule(Double.NaN, newLrSchedule);
|
||||||
|
@ -207,7 +207,7 @@ public class NetworkUtils {
|
||||||
int epoch = net.getEpochCount();
|
int epoch = net.getEpochCount();
|
||||||
if (l instanceof BaseLayerConfiguration) {
|
if (l instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u != null && u.hasLearningRate()) {
|
if (u != null && u.hasLearningRate()) {
|
||||||
double d = u.getLearningRate(iter, epoch);
|
double d = u.getLearningRate(iter, epoch);
|
||||||
if (Double.isNaN(d)) {
|
if (Double.isNaN(d)) {
|
||||||
|
@ -247,7 +247,7 @@ public class NetworkUtils {
|
||||||
LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration();
|
LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration();
|
||||||
if (l instanceof BaseLayerConfiguration) {
|
if (l instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u != null && u.hasLearningRate()) {
|
if (u != null && u.hasLearningRate()) {
|
||||||
if (newLrSchedule != null) {
|
if (newLrSchedule != null) {
|
||||||
u.setLrAndSchedule(Double.NaN, newLrSchedule);
|
u.setLrAndSchedule(Double.NaN, newLrSchedule);
|
||||||
|
@ -329,7 +329,7 @@ public class NetworkUtils {
|
||||||
int epoch = net.getComputationGraphConfiguration().getEpochCount();
|
int epoch = net.getComputationGraphConfiguration().getEpochCount();
|
||||||
if (l instanceof BaseLayerConfiguration) {
|
if (l instanceof BaseLayerConfiguration) {
|
||||||
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
BaseLayerConfiguration bl = (BaseLayerConfiguration) l;
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
if (u != null && u.hasLearningRate()) {
|
if (u != null && u.hasLearningRate()) {
|
||||||
double d = u.getLearningRate(iter, epoch);
|
double d = u.getLearningRate(iter, epoch);
|
||||||
if (Double.isNaN(d)) {
|
if (Double.isNaN(d)) {
|
||||||
|
|
|
@ -210,14 +210,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
||||||
|
|
||||||
netCopy.fit(data);
|
netCopy.fit(data);
|
||||||
IUpdater expectedUpdater = ((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater();
|
IUpdater expectedUpdater = ((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater();
|
||||||
double expectedLR = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater()).getLearningRate();
|
double expectedLR = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater()).getLearningRate();
|
||||||
double expectedMomentum = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater()).getMomentum();
|
double expectedMomentum = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getUpdater()).getMomentum();
|
||||||
|
|
||||||
IUpdater actualUpdater = ((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater();
|
IUpdater actualUpdater = ((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater();
|
||||||
sparkNet.fit(sparkData);
|
sparkNet.fit(sparkData);
|
||||||
double actualLR = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getLearningRate();
|
double actualLR = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater()).getLearningRate();
|
||||||
double actualMomentum = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getMomentum();
|
double actualMomentum = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getUpdater()).getMomentum();
|
||||||
|
|
||||||
assertEquals(expectedUpdater, actualUpdater);
|
assertEquals(expectedUpdater, actualUpdater);
|
||||||
assertEquals(expectedLR, actualLR, 0.01);
|
assertEquals(expectedLR, actualLR, 0.01);
|
||||||
|
|
|
@ -1580,10 +1580,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
|
zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
|
||||||
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));
|
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));
|
||||||
}
|
}
|
||||||
|
int errorCode = nativeOps.lastErrorCode();
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (errorCode != 0) {
|
||||||
throw new RuntimeException(nativeOps.lastErrorMessage() + " error code: " + nativeOps.lastErrorCode());
|
throw new RuntimeException(
|
||||||
|
nativeOps.lastErrorMessage() + " error code: " + errorCode);
|
||||||
|
}
|
||||||
profilingConfigurableHookOut(op, oc, st);
|
profilingConfigurableHookOut(op, oc, st);
|
||||||
|
|
||||||
return z;
|
return z;
|
||||||
|
|
|
@ -1189,7 +1189,7 @@ public class TrainModule implements UIModule {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
IUpdater u = bl.getIUpdater();
|
IUpdater u = bl.getUpdater();
|
||||||
String us = (u == null ? "" : u.getClass().getSimpleName());
|
String us = (u == null ? "" : u.getClass().getSimpleName());
|
||||||
layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerUpdater"),
|
layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerUpdater"),
|
||||||
us});
|
us});
|
||||||
|
|
Loading…
Reference in New Issue