Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>

Branch: master
Parent: a5dfdcb18f
Commit: 0f21ed9ec5
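Every hunk in this commit applies the same API migration, so it is worth stating once up front: listener registration moves from setListeners(...)/addListeners(...) to addTrainingListeners(...), flattened parameter access moves from params() to getModelParams(), the loss accessor moves from score() to getScore(), and the layer-configuration base type is renamed from BaseLayer to BaseLayerConfiguration. A minimal before/after sketch of the call sites (the surrounding class and method are illustrative, not from this commit; judging by its name, addTrainingListeners presumably appends rather than replaces listeners, but this commit does not show its implementation):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class ApiMigrationSketch {
        static void migrate(MultiLayerNetwork net) {
            // Before this commit:
            //   net.setListeners(new ScoreIterationListener(1));
            //   INDArray p = net.params();
            //   double s = net.score();

            // After this commit:
            net.addTrainingListeners(new ScoreIterationListener(1));
            INDArray p = net.getModelParams(); // flattened view of all layer parameters
            double s = net.getScore();         // loss from the last computeGradientAndScore()/fit()
        }
    }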
@@ -47,6 +47,7 @@ import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -54,9 +55,11 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -181,6 +184,7 @@ public class App {
         .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
         .gradientNormalizationThreshold( 100 )
         //.weightInitFn( new WeightInitXavier() ) //this is internal
+        .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
         .weightInit( WeightInit.XAVIER)
         //.activationFn( new ActivationIdentity()) //this is internal
         .activation( Activation.IDENTITY )
@@ -232,10 +236,10 @@ public class App {
 
         copyParams(gen, dis, gan);
 
-        //gen.setListeners(new PerformanceListener(10, true));
+        gen.addTrainingListeners(new PerformanceListener(10, true));
-        //dis.setListeners(new PerformanceListener(10, true));
+        dis.addTrainingListeners(new PerformanceListener(10, true));
-        //gan.setListeners(new PerformanceListener(10, true));
+        gan.addTrainingListeners(new PerformanceListener(10, true));
-        gan.setListeners(new ScoreToChartListener("gan"));
+        gan.addTrainingListeners(new ScoreToChartListener("gan"));
         //dis.setListeners(new ScoreToChartListener("dis"));
 
         gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
@@ -322,23 +326,25 @@ public class App {
         int genLayerCount = gen.getLayers().length;
         for (int i = 0; i < gan.getLayers().length; i++) {
             if (i < genLayerCount) {
-                gen.getLayer(i).setParams(gan.getLayer(i).params());
+                if(gan.getLayer(i).getParams() != null)
+                    gen.getLayer(i).setParams(gan.getLayer(i).getParams());
             } else {
-                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+                if(gan.getLayer(i).getParams() != null)
+                    dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
             }
         }
     }
 
     private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
         for (int i = 0; i < gen.getLayers().length; i++) {
-            gen.getLayer(i).setParams(gan.getLayer(i).params());
+            gen.getLayer(i).setParams(gan.getLayer(i).getParams());
         }
     }
 
     private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
         int genLayerCount = gen.getLayers().length;
         for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
+            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
         }
     }
 
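The new null checks in copyParams above matter because not every layer in the assembled GAN carries parameters (activation, dropout, and similar layers report none); calling setParams with a null array for those layers is what the guard avoids. A condensed sketch of the guarded copy, assuming getParams() returns null for parameterless layers as the added check implies:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class ParamCopySketch {
        // Copy the stacked GAN's layer parameters back into the generator and
        // discriminator, skipping layers that expose no parameters.
        static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
            int genLayerCount = gen.getLayers().length;
            for (int i = 0; i < gan.getLayers().length; i++) {
                INDArray p = gan.getLayer(i).getParams();
                if (p == null) continue; // e.g. activation or dropout layers carry no weights
                if (i < genLayerCount) {
                    gen.getLayer(i).setParams(p);
                } else {
                    dis.getLayer(i - genLayerCount).setParams(p);
                }
            }
        }
    }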
@@ -115,15 +115,15 @@ public class GAN {
 
 
     public void setGeneratorListeners(BaseTrainingListener[] listeners) {
-        generator.setListeners(listeners);
+        generator.addTrainingListeners(listeners);
     }
 
     public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
-        discriminator.setListeners(listeners);
+        discriminator.addTrainingListeners(listeners);
     }
 
     public void setGanListeners(BaseTrainingListener[] listeners) {
-        gan.setListeners(listeners);
+        gan.addTrainingListeners(listeners);
     }
 
     public void fit(DataSetIterator realData, int numEpochs) {
@@ -239,9 +239,9 @@ public class GAN {
         int genLayerCount = generator.getLayers().length;
         for (int i = 0; i < gan.getLayers().length; i++) {
             if (i < genLayerCount) {
-                generator.getLayer(i).setParams(gan.getLayer(i).params());
+                generator.getLayer(i).setParams(gan.getLayer(i).getParams());
             } else {
-                discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+                discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
             }
         }
     }
@@ -252,7 +252,7 @@ public class GAN {
     */
     private void updateGeneratorFromGan() {
         for (int i = 0; i < generator.getLayers().length; i++) {
-            generator.getLayer(i).setParams(gan.getLayer(i).params());
+            generator.getLayer(i).setParams(gan.getLayer(i).getParams());
         }
     }
 
@@ -263,7 +263,7 @@ public class GAN {
     private void updateGanWithDiscriminator() {
         int genLayerCount = generator.getLayers().length;
         for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-            gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
+            gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).getParams());
         }
     }
 
@@ -155,8 +155,8 @@ public class MnistDCGANExample {
             .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
             .build();
 
-        gan.getGenerator().setListeners(new PerformanceListener(1, true));
+        gan.getGenerator().addTrainingListeners(new PerformanceListener(1, true));
-        gan.getDiscriminator().setListeners(new PerformanceListener(1, true));
+        gan.getDiscriminator().addTrainingListeners(new PerformanceListener(1, true));
 
         Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 
@@ -205,7 +205,7 @@ public class TestServer2 {
         //PostgresStatsStorage psqlStore = new PostgresStatsStorage();
         int listenerFrequency = 2;
         //net.setListeners(new StatsListener(psqlStore, listenerFrequency), new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
-        net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
+        net.addTrainingListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
 
 
         //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
@@ -290,7 +290,7 @@ public class IntegrationTestBaselineGenerator {
                 for (int i : layersToTrain) {
                     mln.pretrainLayer(i, dsi);
                 }
-                paramsPostTraining = mln.params();
+                paramsPostTraining = mln.getModelParams();
             } else if (modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                 Preconditions.checkState(layersToTrain != null, "ILayer names must not be null");
@@ -298,7 +298,7 @@ public class IntegrationTestBaselineGenerator {
                 for (String i : layersToTrain) {
                     cg.pretrainLayer(i, iter);
                 }
-                paramsPostTraining = cg.params();
+                paramsPostTraining = cg.getModelParams();
             } else {
                 throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
             }
@@ -314,7 +314,7 @@ public class IntegrationTestBaselineGenerator {
 
             CollectScoresListener l = new CollectScoresListener(1);
             if (modelType != ModelType.SAMEDIFF)
-                m.setListeners(l);
+                m.addTrainingListeners(l);
 
             History h = null;
             if (modelType == ModelType.MLN) {
@@ -349,7 +349,7 @@ public class IntegrationTestBaselineGenerator {
                 }
             } else {
                 File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
-                IntegrationTestRunner.write(m.params(), p);
+                IntegrationTestRunner.write(m.getModelParams(), p);
             }
         }
     }
@@ -191,7 +191,7 @@ public class IntegrationTestRunner {
 
             MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
             assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal");
-            assertEquals( loaded.params(), mln.params(), "Params not equal");
+            assertEquals( loaded.getModelParams(), mln.getModelParams(), "Params not equal");
             assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal");
         } else if(config instanceof ComputationGraphConfiguration ){
             ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
@@ -201,7 +201,7 @@ public class IntegrationTestRunner {
 
             ComputationGraph loaded = ComputationGraph.load(savedModel, true);
             assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" );
-            assertEquals( loaded.params(), cg.params(), "Params not equal");
+            assertEquals( loaded.getModelParams(), cg.getModelParams(), "Params not equal");
             assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal");
         } else if(config instanceof SameDiff){
             sd = (SameDiff)config;
@@ -389,7 +389,7 @@ public class IntegrationTestRunner {
                 for( int i : layersToTrain){
                     mln.pretrainLayer(i, dsi);
                 }
-                paramsPostTraining = mln.params();
+                paramsPostTraining = mln.getModelParams();
                 layers = mln.getLayers();
             } else if(modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
@@ -398,7 +398,7 @@ public class IntegrationTestRunner {
                 for( String i : layersToTrain){
                     cg.pretrainLayer(i, iter);
                 }
-                paramsPostTraining = cg.params();
+                paramsPostTraining = cg.getModelParams();
                 layers = cg.getLayers();
             } else {
                 throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
@@ -439,7 +439,7 @@ public class IntegrationTestRunner {
         CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
         CollectScoresListener l = new CollectScoresListener(1);
         if(modelType != ModelType.SAMEDIFF) {
-            m.setListeners(l);
+            m.addTrainingListeners(l);
         }
 
         int iterBefore;
@@ -519,10 +519,10 @@ public class IntegrationTestRunner {
         if(modelType != ModelType.SAMEDIFF) {
             File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
             INDArray paramsExp = read(p);
-            INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
+            INDArray z = exceedsRelError(m.getModelParams(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
             int count = z.sumNumber().intValue();
             if (count > 0) {
-                logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
+                logFailedParams(20, "Parameter", layers, z, paramsExp, m.getModelParams());
             }
             assertEquals( 0, count, "Number of params exceeded max relative error");
         } else {
@@ -607,12 +607,12 @@ public class IntegrationTestRunner {
             ModelSerializer.writeModel(m, f, true);
             MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
             assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration());
-            assertEquals(mln.params(), restored.params());
+            assertEquals(mln.getModelParams(), restored.getModelParams());
         } else if(modelType == ModelType.CG){
             ModelSerializer.writeModel(m, f, true);
             ComputationGraph restored = ComputationGraph.load(f, true);
             assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-            assertEquals(cg.params(), restored.params());
+            assertEquals(cg.getModelParams(), restored.getModelParams());
         } else {
             sd.save(f, true);
             SameDiff restored = SameDiff.load(f, true);
@@ -49,7 +49,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 
             assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -74,7 +74,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreComputationGraph(bais, true);
 
             assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -26,7 +26,7 @@ import org.nd4j.common.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-/**
+/** The ActivationIdentity activation function, just returns the input as is.
  * f(x) = x
  */
 @EqualsAndHashCode(callSuper = false)
@@ -195,7 +195,7 @@ public abstract class BaseWorkspaceMgr<T extends Enum<T>> implements WorkspaceMg
     }
 
     @Override
-    public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
+    public INDArray validateArrayLocation(T arrayType, INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
         validateConfig(arrayType);
 
         if(scopeOutOfWs.contains(arrayType)){
@@ -19,6 +19,7 @@ dependencies {
     testImplementation projects.cavisNative.cavisNativeCommon
     testImplementation projects.cavisNd4j.cavisNd4jCommonTests
     testImplementation projects.cavisDnn.cavisDnnCommonTests
+    testImplementation projects.cavisDnn.cavisDnnNn
 
     implementation "org.apache.commons:commons-lang3"
 
@@ -116,7 +116,7 @@ public class LayerHelperValidationUtil {
 
         MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net2With.init();
-        net2With.params().assign(netOrig.params());
+        net2With.getModelParams().assign(netOrig.getModelParams());
         log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses());
         removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
@@ -124,7 +124,7 @@ public class LayerHelperValidationUtil {
         Preconditions.checkNotNull(t.getFeatures(), "Features are not set (null)");
 
         for (boolean train : new boolean[]{false, true}) {
-            assertEquals(net1NoHelper.params(), net2With.params());
+            assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams());
             String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
             List<INDArray> ff1;
             try {
@@ -180,7 +180,7 @@ public class LayerHelperValidationUtil {
             double maxRE = relError.maxNumber().doubleValue();
             log.info(s + "Output, max relative error: " + maxRE);
 
-            assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
+            assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams()); //Check that forward pass does not modify params
             assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
         }
     }
@@ -255,24 +255,24 @@ public class LayerHelperValidationUtil {
 
         net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net2With.init();
-        net2With.params().assign(netOrig.params());
+        net2With.getModelParams().assign(netOrig.getModelParams());
         log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
         removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
         CollectScoresListener listener = new CollectScoresListener(1);
-        net2With.setListeners(listener);
+        net2With.addTrainingListeners(listener);
         net2With.fit(t.getData());
 
         for( int i=0; i<2; i++ ) {
 
             net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
             net2With.init();
-            net2With.params().assign(netOrig.params());
+            net2With.getModelParams().assign(netOrig.getModelParams());
             log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
             removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
             CollectScoresListener listener2 = new CollectScoresListener(1);
-            net2With.setListeners(listener2);
+            net2With.addTrainingListeners(listener2);
             net2With.fit(t.getData());
 
             DoubleArrayList listOrig = listener.getListScore();
@@ -25,7 +25,7 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
@@ -67,7 +67,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 
             assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -91,7 +91,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreComputationGraph(bais, true);
 
             assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -205,8 +205,8 @@ public class TestUtils {
         return null;
     }
 
-    public static L2Regularization getL2Reg(BaseLayer baseLayer){
-        return getL2Reg(baseLayer.getRegularization());
+    public static L2Regularization getL2Reg(BaseLayerConfiguration baseLayerConfiguration){
+        return getL2Reg(baseLayerConfiguration.getRegularization());
     }
 
     public static L2Regularization getL2Reg(List<Regularization> l){
@@ -218,7 +218,7 @@ public class TestUtils {
         return null;
     }
 
-    public static WeightDecay getWeightDecayReg(BaseLayer bl){
+    public static WeightDecay getWeightDecayReg(BaseLayerConfiguration bl){
         return getWeightDecayReg(bl.getRegularization());
     }
 
@@ -231,7 +231,7 @@ public class TestUtils {
         return null;
     }
 
-    public static double getL1(BaseLayer layer) {
+    public static double getL1(BaseLayerConfiguration layer) {
         List<Regularization> l = layer.getRegularization();
         return getL1(l);
     }
@@ -246,7 +246,7 @@ public class TestUtils {
         return l1Reg.getL1().valueAt(0,0);
     }
 
-    public static double getL2(BaseLayer layer) {
+    public static double getL2(BaseLayerConfiguration layer) {
         List<Regularization> l = layer.getRegularization();
         return getL2(l);
     }
@@ -269,7 +269,7 @@ public class TestUtils {
         return getL2(layer.getRegularization());
     }
 
-    public static double getWeightDecay(BaseLayer layer) {
+    public static double getWeightDecay(BaseLayerConfiguration layer) {
         return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
     }
 
@@ -32,7 +32,6 @@ import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -183,7 +182,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
         model.init();
 
-        model.setListeners(new ScoreIterationListener(listenerFreq));
+        model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
 
         model.fit(lfw.next());
 
@@ -247,7 +246,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
 
         CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
-        model.setListeners(listener);
+        model.addTrainingListeners(listener);
 
         model.fit(cifar);
 
@@ -226,7 +226,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -255,7 +255,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
@@ -304,7 +304,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -343,7 +343,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 
@@ -386,7 +386,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 
@@ -430,7 +430,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
         int nSamples = 100;
         //Generate the training data
         INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
@@ -473,7 +473,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
@@ -496,9 +496,9 @@ public class TestEarlyStopping extends BaseDL4JTest {
 
         assertEquals(net.getnLayers(), mln.getnLayers());
         assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
-        BaseLayer bl = (BaseLayer) net.getLayerConfiguration();
-        assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.getLayerConfiguration()).getActivationFn().toString());
-        assertEquals(bl.getIUpdater(), ((BaseLayer) mln.getLayerConfiguration()).getIUpdater());
+        BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
+        assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
+        assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater());
     }
 
     @Test
@@ -511,7 +511,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
             .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -792,7 +792,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
 
         TestListener tl = new TestListener();
-        net.setListeners(tl);
+        net.addTrainingListeners(tl);
 
         DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -84,7 +84,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
             .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -128,7 +128,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
             .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -165,7 +165,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
             .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 
@@ -207,7 +207,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
             .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 
@@ -241,7 +241,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
             .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -538,7 +538,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
         ComputationGraph net = new ComputationGraph(conf);
 
         TestEarlyStopping.TestListener tl = new TestEarlyStopping.TestListener();
-        net.setListeners(tl);
+        net.addTrainingListeners(tl);
 
         DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -84,7 +84,7 @@ public class EvalTest extends BaseDL4JTest {
         // Instantiate model
         MultiLayerNetwork model = new MultiLayerNetwork(conf);
         model.init();
-        model.addListeners(new ScoreIterationListener(1));
+        model.addTrainingListeners(new ScoreIterationListener(1));
 
         // Train-test split
         DataSetIterator iter = new IrisDataSetIterator(150, 150);
@@ -324,7 +324,7 @@ public class EvalTest extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
         net2.init();
 
-        net2.setParams(net1.params());
+        net2.setParams(net1.getModelParams());
 
         for(boolean useMask : new boolean[]{false, true}) {
 
@@ -405,7 +405,7 @@ public class EvalTest extends BaseDL4JTest {
         ComputationGraph net2 = new ComputationGraph(conf2);
         net2.init();
 
-        net2.setParams(net1.params());
+        net2.setParams(net1.getModelParams());
 
         for (boolean useMask : new boolean[]{false, true}) {
 
@@ -492,7 +492,7 @@ public class EvalTest extends BaseDL4JTest {
         DataSetIterator iter = new IrisDataSetIterator(30, 150);
         DataSetIterator iterTest = new IrisDataSetIterator(30, 150);
 
-        net.setListeners(new EvaluativeListener(iterTest, 3));
+        net.addTrainingListeners(new EvaluativeListener(iterTest, 3));
 
         for( int i=0; i<3; i++ ){
             net.fit(iter);
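The gradient-check hunks that follow all repeat one preflight pattern: score the network, fit a few times, score again, and only proceed when the score dropped, since a gradient check in 'characteristic mode of operation' is meaningless for a network that is not learning. A condensed sketch of that pattern using the renamed getScore() accessor (the helper name, the epoch count, and the 0.8 threshold are illustrative, not taken from this commit):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.dataset.DataSet;

    class ScoreDecreaseCheckSketch {
        // Verify that a few epochs of training reduce the loss before running
        // the actual gradient check.
        static void assertScoreDecreases(MultiLayerNetwork mln, DataSet ds, int epochs) {
            mln.setInput(ds.getFeatures());
            mln.setLabels(ds.getLabels());
            mln.computeGradientAndScore();
            double scoreBefore = mln.getScore(); // was mln.score() before this commit

            for (int k = 0; k < epochs; k++)
                mln.fit(ds);

            mln.computeGradientAndScore();
            double scoreAfter = mln.getScore();
            if (!(scoreAfter < 0.8 * scoreBefore)) // 0.8 factor is illustrative
                throw new AssertionError("score did not (sufficiently) decrease during learning");
        }
    }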
@@ -26,7 +26,6 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -219,11 +218,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
             mln.setInput(ds.getFeatures());
             mln.setLabels(ds.getLabels());
             mln.computeGradientAndScore();
-            double scoreBefore = mln.score();
+            double scoreBefore = mln.getScore();
             for (int k = 0; k < 20; k++)
                 mln.fit(ds);
             mln.computeGradientAndScore();
-            double scoreAfter = mln.score();
+            double scoreAfter = mln.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = name
                 + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -323,11 +322,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
             mln.setInput(ds.getFeatures());
             mln.setLabels(ds.getLabels());
             mln.computeGradientAndScore();
-            double scoreBefore = mln.score();
+            double scoreBefore = mln.getScore();
             for (int k = 0; k < 10; k++)
                 mln.fit(ds);
             mln.computeGradientAndScore();
-            double scoreAfter = mln.score();
+            double scoreAfter = mln.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = name
                 + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -554,11 +553,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
             net.setInput(0, ds.getFeatures());
             net.setLabels(ds.getLabels());
             net.computeGradientAndScore();
-            double scoreBefore = net.score();
+            double scoreBefore = net.getScore();
             for (int k = 0; k < 20; k++)
                 net.fit(ds);
             net.computeGradientAndScore();
-            double scoreAfter = net.score();
+            double scoreAfter = net.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = name
                 + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -27,7 +27,6 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
@@ -120,11 +119,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
             mln.setInput(ds.getFeatures());
             mln.setLabels(ds.getLabels());
             mln.computeGradientAndScore();
-            double scoreBefore = mln.score();
+            double scoreBefore = mln.getScore();
             for (int j = 0; j < 10; j++)
                 mln.fit(ds);
             mln.computeGradientAndScore();
-            double scoreAfter = mln.score();
+            double scoreAfter = mln.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = name + " - score did not (sufficiently) decrease during learning - activationFn="
                 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -212,11 +211,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
             mln.setInput(ds.getFeatures());
             mln.setLabels(ds.getLabels());
             mln.computeGradientAndScore();
-            double scoreBefore = mln.score();
+            double scoreBefore = mln.getScore();
             for (int j = 0; j < 10; j++)
                 mln.fit(ds);
             mln.computeGradientAndScore();
-            double scoreAfter = mln.score();
+            double scoreAfter = mln.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = testName
                 + "- score did not (sufficiently) decrease during learning - activationFn="
|
|
|
@@ -105,11 +105,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testMinibatchApplication() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation

@@ -184,11 +184,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation

@@ -278,11 +278,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation

@@ -452,11 +452,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation

@@ -523,13 +523,13 @@ public class GradientCheckTests extends BaseDL4JTest {
 netGraph.setInputs(features);
 netGraph.setLabels(labels);
 netGraph.computeGradientAndScore();
-double scoreBefore = netGraph.score();
+double scoreBefore = netGraph.getScore();

 String msg;
 for (int epoch = 0; epoch < 5; epoch++)
 netGraph.fit(new INDArray[]{features}, new INDArray[]{labels});
 netGraph.computeGradientAndScore();
-double scoreAfter = netGraph.score();
+double scoreAfter = netGraph.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 msg = "elementWiseMultiplicationLayerTest() - score did not (sufficiently) decrease during learning - activationFn="
 + "Id" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id"

@@ -757,11 +757,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", layerNorm=" + layerNorm + ", outputActivation=" + outputActivation
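The elementWiseMultiplicationLayerTest hunk shows the ComputationGraph flavour of the same check, where fit takes one array per graph input and output. A minimal sketch, assuming a configured ComputationGraph and matching features/labels arrays as in that hunk:

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

class GraphScoreSketch {
    static double fitAndScoreRatio(ComputationGraph netGraph, INDArray features, INDArray labels) {
        netGraph.setInputs(features);
        netGraph.setLabels(labels);
        netGraph.computeGradientAndScore();
        double scoreBefore = netGraph.getScore();        // formerly netGraph.score()
        for (int epoch = 0; epoch < 5; epoch++) {
            // ComputationGraph#fit takes arrays of inputs and labels (multi-input graphs).
            netGraph.fit(new INDArray[]{features}, new INDArray[]{labels});
        }
        netGraph.computeGradientAndScore();
        return netGraph.getScore() / scoreBefore;        // ratio < 1 means learning happened
    }
}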
@@ -666,7 +666,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
 net.init();

 //Check params to avoid test flakiness on small or large params
-INDArray params = net.params();
+INDArray params = net.getModelParams();
 for( int x=0; x<params.length(); x++ ){
 while(Math.abs(params.getDouble(x)) < 0.01 || Math.abs(params.getDouble(x)) > 1.5){
 double d = Nd4j.getRandom().nextDouble();
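Here the flattened parameter view is reached via getModelParams() instead of params(). The hunk is cut off before the resampling write-back, so the following sketch fills that step in as an assumption (a putScalar into the same view):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class ParamClampSketch {
    // Resample any parameter whose magnitude falls outside [0.01, 1.5], to avoid flaky
    // finite-difference checks. The write-back line is an assumption; the diff truncates there.
    static void resampleOutOfRange(MultiLayerNetwork net) {
        INDArray params = net.getModelParams();          // formerly net.params()
        for (long x = 0; x < params.length(); x++) {
            while (Math.abs(params.getDouble(x)) < 0.01 || Math.abs(params.getDouble(x)) > 1.5) {
                double d = Nd4j.getRandom().nextDouble();
                params.putScalar(x, 0.01 + d);           // assumed write-back; keeps the value in range
            }
        }
    }
}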
@@ -37,10 +37,9 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.exception.DL4JInvalidConfigException;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;

@@ -254,8 +253,8 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
 MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
 model2.init();

-float[] p1 = model1.params().data().asFloat();
+float[] p1 = model1.getModelParams().data().asFloat();
-float[] p2 = model2.params().data().asFloat();
+float[] p2 = model2.getModelParams().data().asFloat();
 System.out.println(Arrays.toString(p1));
 System.out.println(Arrays.toString(p2));
@@ -266,20 +265,20 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
 public void testTrainingListener() {
 MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
 model1.init();
-model1.addListeners(new ScoreIterationListener(1));
+model1.addTrainingListeners(new ScoreIterationListener(1));

 MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
-model2.addListeners(new ScoreIterationListener(1));
+model2.addTrainingListeners(new ScoreIterationListener(1));
 model2.init();

 Layer[] l1 = model1.getLayers();
 for (int i = 0; i < l1.length; i++) {
-assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1);
+assertTrue(l1[i].getTrainingListeners() != null && l1[i].getTrainingListeners().size() == 1);
 }

 Layer[] l2 = model2.getLayers();
 for (int i = 0; i < l2.length; i++) {
-assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1);
+assertTrue(l2[i].getTrainingListeners() != null && l2[i].getTrainingListeners().size() == 1);
 }
 }
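Listener registration follows the same rename pattern: addListeners/setListeners become addTrainingListeners, and getListeners becomes getTrainingListeners; the test exercises registration both before and after init(). A minimal sketch (the wrapper class is illustrative; the configuration is assumed valid, like the test's getConf()):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

class ListenerSketch {
    static void attach(MultiLayerNetwork model) {
        model.init();
        model.addTrainingListeners(new ScoreIterationListener(1)); // formerly addListeners(...)
        // Listeners are propagated to every layer:
        assert model.getLayers()[0].getTrainingListeners().size() == 1;
    }
}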
@@ -384,10 +383,10 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
 .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
 .inputType(InputType.convolutional(28, 28, 1)).build();

-org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer();
+BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
-org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer();
+BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
-org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer();
+BaseLayerConfiguration l2 = (BaseLayerConfiguration) conf.getConf(2).getLayer();
-org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer();
+BaseLayerConfiguration l3 = (BaseLayerConfiguration) conf.getConf(3).getLayer();

 assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
 assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);

@@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.BatchNormalization;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;

@@ -100,7 +100,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
 @Test
 public void testClone() {
 NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
-BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0);
+BaseLayerConfiguration bl = (BaseLayerConfiguration) conf.getFlattenedLayerConfigurations().get(0);
 conf.setStepFunction(new DefaultStepFunction());

 NeuralNetConfiguration conf2 = conf.clone();

@@ -158,7 +158,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
 cg.setInput(0, input);
 cg.setLabel(0, target);
 cg.computeGradientAndScore();
-double score_dl4j = cg.score();
+double score_dl4j = cg.getScore();
 Map<String, INDArray> weights = cg.getParamTable();
 Gradient g = cg.gradient();
 Map<String, INDArray> gradients = g.gradientForVariable();
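The layer-configuration base class is renamed from BaseLayer to BaseLayerConfiguration; only the import and the casts change at call sites. A minimal sketch of reading per-layer settings through the new type, assuming a built NeuralNetConfiguration with at least one layer:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;

class ConfigCastSketch {
    static void inspect(NeuralNetConfiguration conf) {
        // formerly: BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer();
        BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
        System.out.println(l0.getActivationFn());  // effective per-layer activation
        System.out.println(l0.getIUpdater());      // effective per-layer updater
    }
}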
@@ -72,8 +72,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
+assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
-assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
+assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());

 //With
 conf = NeuralNetConfiguration.builder().activation(Activation.RELU)

@@ -83,8 +83,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
+assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
-assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
+assertEquals("tanh", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());
 }

@@ -99,11 +99,11 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());

-assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
+assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
-assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
+assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);

 //With:
 final Distribution overriddenDistribution = new UniformDistribution(0, 1);

@@ -117,11 +117,11 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
-assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
+assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());

-assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
+assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
-assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
+assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
 }

 /*
@@ -137,8 +137,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
+assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
-assertEquals(0.3, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
+assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);

 //With:
 conf = NeuralNetConfiguration.builder().learningRate(0.3)

@@ -148,8 +148,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
+assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
-assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
+assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);

 //L1 and L2 without layerwise override:
 conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)

@@ -158,10 +158,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.1, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0);
+assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
-assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0);
+assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
-assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
+assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
-assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
+assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);

 //L1 and L2 with layerwise override:
 conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)

@@ -170,10 +170,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.9, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0);
+assertEquals(0.9, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
-assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0);
+assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
-assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
+assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
-assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
+assertEquals(0.8, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);
 }*/

@@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);

 Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
 testMomentumAfter2.put(0, 0.2);

@@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest {

 net = new MultiLayerNetwork(conf);
 net.init();
-assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
 }

 @Test
@@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
+assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
-assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
+assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
+assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
-assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);

 conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
 .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())

@@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
+assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
-assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
+assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
+assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
-assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
 }

@@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
+assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
-assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
+assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
-assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
+assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
-assertEquals(0.7, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
+assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
 }

 @Test
@@ -287,13 +287,11 @@ public class LayerConfigTest extends BaseDL4JTest {
 .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
+BaseLayerConfiguration bconf = (BaseLayerConfiguration) conf.getConf(0).getLayer();
-assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-conf.getConf(0).getLayer().getGradientNormalization());
-assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-conf.getConf(1).getLayer().getGradientNormalization());
-assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
-assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
+assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
+assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
+assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
+assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);

 //With:
 conf = NeuralNetConfiguration.builder()

@@ -308,11 +306,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-conf.getConf(0).getLayer().getGradientNormalization());
-assertEquals(GradientNormalization.None, conf.getConf(1).getLayer().getGradientNormalization());
-assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
-assertEquals(2.5, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
+assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
+assertEquals(GradientNormalization.None, bconf.getGradientNormalization());
+assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
+assertEquals(2.5, bconf.getGradientNormalizationThreshold(), 0.0);
 }
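All of these assertions read the effective per-layer values back through BaseLayerConfiguration. A minimal sketch of the default-plus-override pattern under test, built from builder calls that appear in the hunks above (the nIn/nOut sizes come from the diff; the wrapper class and the inheritance expectation for layer 1 are assumptions):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.learning.config.RmsProp;

class UpdaterOverrideSketch {
    static void demo() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))  // network default
                .layer(0, new DenseLayer.Builder().nIn(2).nOut(2)
                        .updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()) // override
                .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())        // no override
                .build();

        // Effective values are read from the per-layer configuration after build():
        double rmsDecay0 = ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay();
        double rmsDecay1 = ((RmsProp) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRmsDecay();
        System.out.println("layer0 rmsDecay=" + rmsDecay0 + ", layer1 rmsDecay=" + rmsDecay1);
        // Expected (per the tests above): layer 0 keeps its override (1.0);
        // layer 1 inherits the network default (2.0) -- an assumption here.
    }
}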
@@ -162,12 +162,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-BaseLayer layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
 assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
 assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
 assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);

-BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);

 // Adam Updater

@@ -178,11 +178,11 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
 assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3);
 assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);

-layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
 assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
 assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());

@@ -196,12 +196,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
 net = new MultiLayerNetwork(conf);
 net.init();

-layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
 assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
 assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
 assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));

-layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);
@@ -29,7 +29,7 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;

@@ -75,9 +75,9 @@ public class TestWeightNoise extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-assertEquals(wn, ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
+assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
-assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
+assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
-assertEquals(wn, ((BaseLayer) net.getLayer(2).getLayerConfiguration()).getWeightNoise());
+assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(2).getLayerConfiguration()).getWeightNoise());

 TestUtils.testModelSerialization(net);

@@ -95,9 +95,9 @@ public class TestWeightNoise extends BaseDL4JTest {
 ComputationGraph graph = new ComputationGraph(conf2);
 graph.init();

-assertEquals(wn, ((BaseLayer) graph.getLayer(0).getLayerConfiguration()).getWeightNoise());
+assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(0).getLayerConfiguration()).getWeightNoise());
-assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).getLayerConfiguration()).getWeightNoise());
+assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) graph.getLayer(1).getLayerConfiguration()).getWeightNoise());
-assertEquals(wn, ((BaseLayer) graph.getLayer(2).getLayerConfiguration()).getWeightNoise());
+assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(2).getLayerConfiguration()).getWeightNoise());

 TestUtils.testModelSerialization(graph);
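Weight-noise settings are likewise read back through BaseLayerConfiguration.getWeightNoise(). A minimal sketch of verifying a per-layer DropConnect override, assuming an initialized network shaped like the one in TestWeightNoise (the wrapper class is illustrative):

import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class WeightNoiseSketch {
    static void verify(MultiLayerNetwork net) {
        // Layer 1 carries a DropConnect(0.25) override in the test; read it back through
        // the renamed configuration type (formerly a BaseLayer cast):
        BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
        assert new DropConnect(0.25).equals(l1.getWeightNoise());
    }
}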
@@ -124,7 +124,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
 import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
-import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
 import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
 import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
 import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;

@@ -260,8 +260,8 @@ public class DTypeTests extends BaseDL4JTest {
 for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) {
 LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0);
 seenLayers.add(l.getClass());
-if (l instanceof BaseWrapperLayer) {
+if (l instanceof BaseWrapperLayerConfiguration) {
-BaseWrapperLayer bwl = (BaseWrapperLayer) l;
+BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration) l;
 seenLayers.add(bwl.getUnderlying().getClass());
 } else if (l instanceof Bidirectional) {
 seenLayers.add(((Bidirectional) l).getFwd().getClass());

@@ -321,17 +321,17 @@ public class DTypeTests extends BaseDL4JTest {
 net.setInput(inD);
 net.setLabels(lD);
 net.computeGradientAndScore();
-double scoreDouble = net.score();
+double scoreDouble = net.getScore();
 INDArray grads = net.getFlattenedGradients();
 INDArray u = net.getUpdater().getStateViewArray();
-assertEquals(DataType.DOUBLE, net.params().dataType());
+assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
 assertEquals(DataType.DOUBLE, grads.dataType());
 assertEquals(DataType.DOUBLE, u.dataType());

 MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT);
 netFloat.initGradientsView();
-assertEquals(DataType.FLOAT, netFloat.params().dataType());
+assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
 assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
 assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
 INDArray inF = inD.castTo(DataType.FLOAT);

@@ -340,7 +340,7 @@ public class DTypeTests extends BaseDL4JTest {
 netFloat.setInput(inF);
 netFloat.setLabels(lF);
 netFloat.computeGradientAndScore();
-double scoreFloat = netFloat.score();
+double scoreFloat = netFloat.getScore();
 INDArray gradsFloat = netFloat.getFlattenedGradients();
 INDArray uFloat = netFloat.getUpdater().getStateViewArray();

@@ -352,7 +352,7 @@ public class DTypeTests extends BaseDL4JTest {

 MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF);
 netFP16.initGradientsView();
-assertEquals(DataType.HALF, netFP16.params().dataType());
+assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
 assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
 assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());

@@ -362,7 +362,7 @@ public class DTypeTests extends BaseDL4JTest {
 netFP16.setInput(inH);
 netFP16.setLabels(lH);
 netFP16.computeGradientAndScore();
-double scoreHalf = netFP16.score();
+double scoreHalf = netFP16.getScore();
 INDArray gradsHalf = netFP16.getFlattenedGradients();
 INDArray uHalf = netFP16.getUpdater().getStateViewArray();

@@ -406,17 +406,17 @@ public class DTypeTests extends BaseDL4JTest {
 net.setInput(0, inD);
 net.setLabels(lD);
 net.computeGradientAndScore();
-double scoreDouble = net.score();
+double scoreDouble = net.getScore();
 INDArray grads = net.getFlattenedGradients();
 INDArray u = net.getUpdater().getStateViewArray();
-assertEquals(DataType.DOUBLE, net.params().dataType());
+assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
 assertEquals(DataType.DOUBLE, grads.dataType());
 assertEquals(DataType.DOUBLE, u.dataType());

 ComputationGraph netFloat = net.convertDataType(DataType.FLOAT);
 netFloat.initGradientsView();
-assertEquals(DataType.FLOAT, netFloat.params().dataType());
+assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
 assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
 assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
 INDArray inF = inD.castTo(DataType.FLOAT);

@@ -425,7 +425,7 @@ public class DTypeTests extends BaseDL4JTest {
 netFloat.setInput(0, inF);
 netFloat.setLabels(lF);
 netFloat.computeGradientAndScore();
-double scoreFloat = netFloat.score();
+double scoreFloat = netFloat.getScore();
 INDArray gradsFloat = netFloat.getFlattenedGradients();
 INDArray uFloat = netFloat.getUpdater().getStateViewArray();

@@ -437,7 +437,7 @@ public class DTypeTests extends BaseDL4JTest {

 ComputationGraph netFP16 = net.convertDataType(DataType.HALF);
 netFP16.initGradientsView();
-assertEquals(DataType.HALF, netFP16.params().dataType());
+assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
 assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
 assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());

@@ -447,7 +447,7 @@ public class DTypeTests extends BaseDL4JTest {
 netFP16.setInput(0, inH);
 netFP16.setLabels(lH);
 netFP16.computeGradientAndScore();
-double scoreHalf = netFP16.score();
+double scoreHalf = netFP16.getScore();
 INDArray gradsHalf = netFP16.getFlattenedGradients();
 INDArray uHalf = netFP16.getUpdater().getStateViewArray();

@@ -536,7 +536,7 @@ public class DTypeTests extends BaseDL4JTest {
 net.init();

 net.initGradientsView();
-assertEquals(networkDtype, net.params().dataType(), msg);
+assertEquals(networkDtype, net.getModelParams().dataType(), msg);
 assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
 assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -641,7 +641,7 @@ public class DTypeTests extends BaseDL4JTest {
 net.init();

 net.initGradientsView();
-assertEquals(networkDtype, net.params().dataType(), msg);
+assertEquals(networkDtype, net.getModelParams().dataType(), msg);
 assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
 assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -754,7 +754,7 @@ public class DTypeTests extends BaseDL4JTest {
 net.init();

 net.initGradientsView();
-assertEquals(networkDtype, net.params().dataType(), msg);
+assertEquals(networkDtype, net.getModelParams().dataType(), msg);
 assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
 assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -827,7 +827,7 @@ public class DTypeTests extends BaseDL4JTest {
 net.init();

 net.initGradientsView();
-assertEquals(networkDtype, net.params().dataType(), msg);
+assertEquals(networkDtype, net.getModelParams().dataType(), msg);
 assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
 assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -916,7 +916,7 @@ public class DTypeTests extends BaseDL4JTest {
 net.init();

 net.initGradientsView();
-assertEquals(networkDtype, net.params().dataType(), msg);
+assertEquals(networkDtype, net.getModelParams().dataType(), msg);
 assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
 assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
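Across the dtype tests only the accessors change: params() becomes getModelParams() and score() becomes getScore(). A minimal sketch of the conversion check, assuming an initialized MultiLayerNetwork built at DOUBLE precision (the wrapper class is illustrative):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.buffer.DataType;

class DTypeSketch {
    static void convertAndCheck(MultiLayerNetwork net) {
        MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT);
        netFloat.initGradientsView();
        // Parameters, gradients, and updater state must all follow the new dtype:
        assert netFloat.getModelParams().dataType() == DataType.FLOAT;   // formerly params()
        assert netFloat.getFlattenedGradients().dataType() == DataType.FLOAT;
        assert netFloat.getUpdater(true).getStateViewArray().dataType() == DataType.FLOAT;
    }
}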
@@ -520,9 +520,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
 INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
 INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);

-INDArray initialParams = graph.params().dup();
+INDArray initialParams = graph.getModelParams().dup();
 graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong});
-INDArray afterParams = graph.params();
+INDArray afterParams = graph.getModelParams();

 assertNotEquals(initialParams, afterParams);
 }

@@ -117,7 +117,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
 boolean orderOK = Arrays.equals(expOrder1, order) || Arrays.equals(expOrder2, order);
 assertTrue(orderOK);

-INDArray params = graph.params();
+INDArray params = graph.getModelParams();
 assertNotNull(params);

 // confirm param shape is what is expected

@@ -129,7 +129,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {

 // params are set
 graph.setParams(arr);
-params = graph.params();
+params = graph.getModelParams();
 assertEquals(arr, params);

 //Number of inputs and outputs:

@@ -108,7 +108,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
 }
 }

-int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
+int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
 assertEquals(0, count);

@@ -125,7 +125,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
 }
 }

-count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
+count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
 assertEquals(0, count);
 }
 }

@@ -176,7 +176,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
 Nd4j.getRandom().setSeed(12345);
 cg.pretrainLayer("0", ds);

-assertEquals(net.params(), cg.params());
+assertEquals(net.getModelParams(), cg.getModelParams());
 }
 }
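The NaN scan in TestCompGraphUnsupervised is a useful pattern in its own right: count parameter entries matching a condition. A minimal sketch with the renamed accessor; the import paths are an assumption (the test's own import block is not shown in the diff):

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

class NanCheckSketch {
    static int countNaNParams(ComputationGraph cg) {
        // formerly cg.params(); returns the flattened parameter view
        return Nd4j.getExecutioner()
                .exec(new MatchCondition(cg.getModelParams(), Conditions.isNan()))
                .getInt(0);
    }
}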
@ -159,7 +159,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
DataSet ds = iris.next();
|
DataSet ds = iris.next();
|
||||||
|
|
||||||
graph.setInput(0, ds.getFeatures());
|
graph.setInput(0, ds.getFeatures());
|
||||||
net.setParams(graph.params());
|
net.setParams(graph.getModelParams());
|
||||||
Map<String, INDArray> activations = graph.feedForward(false);
|
Map<String, INDArray> activations = graph.feedForward(false);
|
||||||
|
|
||||||
List<INDArray> feedForward = net.feedForward(ds.getFeatures());
|
List<INDArray> feedForward = net.feedForward(ds.getFeatures());
|
||||||
|
@ -184,7 +184,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
int[] expOrder = new int[]{0, 1, 2};
|
 int[] expOrder = new int[]{0, 1, 2};
 assertArrayEquals(expOrder, order); //Only one valid order: 0 (input) -> 1 (firstlayer) -> 2 (outputlayer)

-INDArray params = graph.params();
+INDArray params = graph.getModelParams();
 assertNotNull(params);

 int nParams = getNumParams();

@@ -194,7 +194,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 assertEquals(nParams, arr.length());

 graph.setParams(arr);
-params = graph.params();
+params = graph.getModelParams();
 assertEquals(arr, params);

 //Number of inputs and outputs:

@@ -315,8 +315,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 graph.fit(iris);

 //Check that parameters are equal for both models after fitting:
-INDArray paramsMLN = net.params();
-INDArray paramsGraph = graph.params();
+INDArray paramsMLN = net.getModelParams();
+INDArray paramsGraph = graph.getModelParams();

 assertNotEquals(params, paramsGraph);
 assertEquals(paramsMLN, paramsGraph);

@@ -636,7 +636,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

 ComputationGraph net = new ComputationGraph(conf);
 net.init();
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));

 DataSetIterator iter = new IrisDataSetIterator(10, 150);
 net.pretrain(iter);

@@ -675,7 +675,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

 ComputationGraph netNoReg = new ComputationGraph(confNoReg);
 netNoReg.init();
-netNoReg.setParams(net.params().dup());
+netNoReg.setParams(net.getModelParams().dup());

 //Score single example, and compare to scoreExamples:
 INDArray input = Nd4j.rand(3, nIn);

@@ -878,13 +878,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 net.setParam("first_b", Nd4j.ones(1, 5));
 net.setParam("output_W", Nd4j.ones(5, 3));
 net.setParam("output_b", Nd4j.ones(1, 3));
-INDArray actualParams = net.params();
+INDArray actualParams = net.getModelParams();

 // Confirm params
 assertEquals(Nd4j.ones(1, 43), actualParams);

 net.update(expectedGradient);
-actualParams = net.params();
+actualParams = net.getModelParams();
 assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
 }

@@ -1638,7 +1638,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 conf3.setTopologicalOrderStr(null);
 ComputationGraph cg3 = new ComputationGraph(conf3);
 cg3.init();
-cg3.setParams(cg2.params());
+cg3.setParams(cg2.getModelParams());

 int[] order3 = cg3.topologicalSortOrder();
 List<String> strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr();

@@ -1712,7 +1712,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 exp.add(ComputationGraph.class);

 MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener();
-net.setListeners(listener);
+net.addTrainingListeners(listener);

 INDArray f = Nd4j.create(1,10);
 INDArray l = Nd4j.create(1,10);

@@ -1874,7 +1874,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

 ComputationGraph cg = new ComputationGraph(conf);
 cg.init();
-cg.params().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11));
+cg.getModelParams().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11));

 INDArray p0w = cg.getParam("layer_zero_W");
 assertEquals(Nd4j.linspace(1, 100, 100).reshape('f', 10, 10), p0w);
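The hunks above apply one mechanical API migration across the test suite: params() becomes getModelParams(), score() becomes getScore(), and setListeners(...) becomes addTrainingListeners(...). A minimal before/after sketch of the pattern, assuming a DL4J ComputationGraph and a DataSetIterator as in the tests (buildConf() is a hypothetical config factory, not part of the commit):

    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.api.ndarray.INDArray;

    ComputationGraph graph = new ComputationGraph(buildConf()); // buildConf(): hypothetical helper
    graph.init();
    graph.addTrainingListeners(new ScoreIterationListener(1));  // was: graph.setListeners(...)
    INDArray p = graph.getModelParams();                        // was: graph.params()
    graph.fit(iter);                                            // iter: any DataSetIterator
    double s = graph.getScore();                                // was: graph.score()

The renamed accessors keep the old semantics in these tests; only the names change.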
@@ -56,7 +56,7 @@ public class TestSetGetParameters extends BaseDL4JTest {

 ComputationGraph net = new ComputationGraph(conf);
 net.init();
-INDArray params = net.params();
+INDArray params = net.getModelParams();


 ComputationGraph net2 = new ComputationGraph(conf);

@@ -65,11 +65,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
 ComputationGraph net3 = new ComputationGraph(conf);
 net3.init(params, false);

-assertEquals(params, net2.params());
-assertEquals(params, net3.params());
+assertEquals(params, net2.getModelParams());
+assertEquals(params, net3.getModelParams());

-assertNotSame(params, net2.params()); //Different objects due to clone
-assertSame(params, net3.params()); //Same object due to clone
+assertNotSame(params, net2.getModelParams()); //Different objects due to clone
+assertSame(params, net3.getModelParams()); //Same object due to clone


 Map<String, INDArray> paramsMap = net.getParamTable();
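The assertNotSame/assertSame pair above pins down the copy semantics of init(INDArray, boolean): with cloneParametersArray = false the network adopts the caller's parameter array directly, while true (and the parameterless init()) gives it its own storage. A short sketch of the distinction, assuming the same conf as the test:

    ComputationGraph a = new ComputationGraph(conf);
    a.init();                             // allocates its own flattened parameter array
    INDArray params = a.getModelParams();

    ComputationGraph b = new ComputationGraph(conf);
    b.init(params, true);                 // clone: equal values, but a distinct INDArray object

    ComputationGraph c = new ComputationGraph(conf);
    c.init(params, false);                // no clone: c reuses the very same INDArray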
@@ -103,14 +103,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 net.setInput(0, in1);
 net.setLabel(0, labels1);
 net.computeGradientAndScore();
-double score1 = net.score();
+double score1 = net.getScore();
 Gradient g1 = net.gradient();

 net.setInput(0, in2);
 net.setLabel(0, labels2);
 net.setLayerMaskArrays(null, new INDArray[] {labelMask});
 net.computeGradientAndScore();
-double score2 = net.score();
+double score2 = net.getScore();
 Gradient g2 = net.gradient();

 //Scores and gradients should be identical for two cases (given mask array)

@@ -134,7 +134,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 }
 net.setLabel(0, labels2);
 net.computeGradientAndScore();
-double score2a = net.score();
+double score2a = net.getScore();
 Gradient g2a = net.gradient();
 assertEquals(score2, score2a, 1e-6);
 for (String s : g2map.keySet()) {

@@ -200,7 +200,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 net.setInput(0, in1);
 net.setLabel(0, labels1);
 net.computeGradientAndScore();
-double score1 = net.score();
+double score1 = net.getScore();
 Gradient g1 = net.gradient();
 Map<String, INDArray> map = g1.gradientForVariable();
 for (String s : map.keySet()) {

@@ -211,7 +211,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 net.setLabel(0, labels2);
 net.setLayerMaskArrays(new INDArray[] {inputMask}, null);
 net.computeGradientAndScore();
-double score2 = net.score();
+double score2 = net.getScore();
 Gradient g2 = net.gradient();
 Map<String, INDArray> activations2 = net.feedForward();

@@ -236,7 +236,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 net.setInput(0, in2);
 net.setLayerMaskArrays(new INDArray[]{inputMask}, null);
 net.computeGradientAndScore();
-double score2a = net.score();
+double score2a = net.getScore();
 Gradient g2a = net.gradient();
 assertEquals(score2, score2a, 1e-12);
 for (String s : g2.gradientForVariable().keySet()) {

@@ -330,7 +330,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
 net.setLabel(0, labels);

 net.computeGradientAndScore();
-double score = net.score();
+double score = net.getScore();

 assertEquals(expScore, score, 0.1, msg);
 }
@@ -40,7 +40,7 @@ import java.util.Map;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;

-public class BaseLayerTest extends BaseDL4JTest {
+public class BaseLayerConfigurationTest extends BaseDL4JTest {

 protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2});
 protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2});
@@ -56,10 +56,10 @@ public class CacheModeTest extends BaseDL4JTest {
 INDArray out2 = net2.output(in);
 assertEquals(out1, out2);

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 net1.fit(in, labels);
 net2.fit(in, labels);
-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 }

 private static NeuralNetConfiguration getConf(CacheMode cacheMode){

@@ -99,10 +99,10 @@ public class CacheModeTest extends BaseDL4JTest {
 INDArray out2 = net2.output(in);
 assertEquals(out1, out2);

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 net1.fit(in, labels);
 net2.fit(in, labels);
-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 }
 }

@@ -145,10 +145,10 @@ public class CacheModeTest extends BaseDL4JTest {
 INDArray out2 = net2.outputSingle(in);
 assertEquals(out1, out2);

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 net1.fit(new DataSet(in, labels));
 net2.fit(new DataSet(in, labels));
-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 }

 private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){
@@ -121,7 +121,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
 graph.setInput(0, input);
 graph.setLabel(0, labels);
 graph.computeGradientAndScore();
-results[i] = graph.score();
+results[i] = graph.getScore();
 }

 assertNotEquals(results[0], results[1]);

@@ -137,7 +137,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {

 ComputationGraph net = getCNNMnistConfig();
 net.init();
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));

 for (int i = 0; i < 50; i++) {
 net.fit(mnistTrain.next());
@@ -265,7 +265,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
 MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
 netSeparate.init();

-assertEquals(netIntegrated.params(), netSeparate.params());
+assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());

 Nd4j.getRandom().setSeed(12345);
 netIntegrated.fit(next);

@@ -273,7 +273,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
 Nd4j.getRandom().setSeed(12345);
 netSeparate.fit(next);

-assertEquals(netIntegrated.params(), netSeparate.params());
+assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());

 // check parameters
 assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
@@ -80,7 +80,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 .setFeatureExtractor(1).build();

 INDArray paramsLastTwoLayers =
-Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
+Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
 MultiLayerNetwork notFrozen = new MultiLayerNetwork(
 (NeuralNetConfiguration) overallConf.clone()
 .layer(0, new Builder().nIn(2).nOut(3).build())

@@ -102,9 +102,9 @@ public class FrozenLayerTest extends BaseDL4JTest {
 modelNow.fit(randomData);
 }

-INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
-notFrozen.params());
-INDArray act = modelNow.params();
+INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
+notFrozen.getModelParams());
+INDArray act = modelNow.getModelParams();
 assertEquals(expected, act);
 }

@@ -136,7 +136,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson());

 //Check params
-assertEquals(modelNow.params(), clonedModel.params());
+assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());

 MultiLayerNetwork notFrozen = new MultiLayerNetwork(
 (NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build())

@@ -145,7 +145,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 .activation(Activation.SOFTMAX).nIn(3).nOut(3)
 .build())
 .build(),
-Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
+Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()));

 int i = 0;
 while (i < 5) {

@@ -155,10 +155,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
 i++;
 }

-INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(),
-modelToFineTune.getLayer(1).params(), notFrozen.params());
-assertEquals(expectedParams, modelNow.params());
-assertEquals(expectedParams, clonedModel.params());
+INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(),
+modelToFineTune.getLayer(1).getParams(), notFrozen.getModelParams());
+assertEquals(expectedParams, modelNow.getModelParams());
+assertEquals(expectedParams, clonedModel.getModelParams());

 }

@@ -199,8 +199,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
 .setOutputs("layer1").build());

 notFrozen.init();
-notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
-modelToFineTune.getLayer("layer3").params()));
+notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
+modelToFineTune.getLayer("layer3").getParams()));

 int i = 0;
 while (i < 5) {

@@ -209,8 +209,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
 i++;
 }

-assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
-modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
+assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
+modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams()), modelNow.getModelParams());
 }

 @Test

@@ -244,7 +244,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());

 //Check params
-assertEquals(modelNow.params(), clonedModel.params());
+assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());

 ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
 .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")

@@ -256,8 +256,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
 "layer0")
 .setOutputs("layer1").build());
 notFrozen.init();
-notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
-modelToFineTune.getLayer("layer3").params()));
+notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
+modelToFineTune.getLayer("layer3").getParams()));


 int i = 0;

@@ -268,10 +268,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
 i++;
 }

-INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
-modelToFineTune.getLayer("layer1").params(), notFrozen.params());
-assertEquals(expectedParams, modelNow.params());
-assertEquals(expectedParams, clonedModel.params());
+INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
+modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams());
+assertEquals(expectedParams, modelNow.getModelParams());
+assertEquals(expectedParams, clonedModel.getModelParams());
 }


@@ -305,7 +305,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());


 String json = conf2.toJson();

@@ -362,7 +362,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 ComputationGraph net2 = new ComputationGraph(conf2);
 net2.init();

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());


 String json = conf2.toJson();
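FrozenLayerTest exercises transfer learning with a frozen feature extractor: everything up to and including the layer passed to setFeatureExtractor(...) keeps its pretrained parameters, which is exactly what the hstack comparisons above assert. A minimal sketch of that setup, assuming a pretrained modelToFineTune and randomData as in the tests:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.transferlearning.TransferLearning;

    MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune)
            .setFeatureExtractor(1)   // layers 0 and 1 frozen; layers 2+ keep training
            .build();
    modelNow.fit(randomData);         // frozen layers' getParams() stay bit-identical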
@@ -75,7 +75,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());


 String json = conf2.toJson();

@@ -130,7 +130,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 ComputationGraph net2 = new ComputationGraph(conf2);
 net2.init();

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());


 String json = conf2.toJson();

@@ -170,19 +170,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

 MultiLayerNetwork network = new MultiLayerNetwork(conf1);
 network.init();
-INDArray unfrozenLayerParams = network.getLayer(0).params().dup();
-INDArray frozenLayerParams1 = network.getLayer(1).params().dup();
-INDArray frozenLayerParams2 = network.getLayer(2).params().dup();
-INDArray frozenOutputLayerParams = network.getLayer(3).params().dup();
+INDArray unfrozenLayerParams = network.getLayer(0).getParams().dup();
+INDArray frozenLayerParams1 = network.getLayer(1).getParams().dup();
+INDArray frozenLayerParams2 = network.getLayer(2).getParams().dup();
+INDArray frozenOutputLayerParams = network.getLayer(3).getParams().dup();

 for (int i = 0; i < 100; i++) {
 network.fit(randomData);
 }

-assertNotEquals(unfrozenLayerParams, network.getLayer(0).params());
-assertEquals(frozenLayerParams1, network.getLayer(1).params());
-assertEquals(frozenLayerParams2, network.getLayer(2).params());
-assertEquals(frozenOutputLayerParams, network.getLayer(3).params());
+assertNotEquals(unfrozenLayerParams, network.getLayer(0).getParams());
+assertEquals(frozenLayerParams1, network.getLayer(1).getParams());
+assertEquals(frozenLayerParams2, network.getLayer(2).getParams());
+assertEquals(frozenOutputLayerParams, network.getLayer(3).getParams());

 }

@@ -228,19 +228,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

 ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
 computationGraph.init();
-INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup();
+INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).getParams().dup();

 for (int i = 0; i < 100; i++) {
 computationGraph.fit(randomData);
 }

-assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
-assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params());
-assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params());
-assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params());
+assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
+assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
+assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
+assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).getParams());

 }

@@ -275,17 +275,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 .build();
 MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
 frozenNetwork.init();
-INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup();
-INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup();
-INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup();
-INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup();
+INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).getParams().dup();
+INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).getParams().dup();
+INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).getParams().dup();
+INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).getParams().dup();

 MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
 sgdNetwork.init();
-INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup();
-INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup();
-INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup();
-INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup();
+INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).getParams().dup();
+INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).getParams().dup();
+INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).getParams().dup();
+INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).getParams().dup();

 for (int i = 0; i < 100; i++) {
 frozenNetwork.fit(randomData);

@@ -294,10 +294,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 sgdNetwork.fit(randomData);
 }

-assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params());
-assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params());
-assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params());
-assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params());
+assertEquals(frozenNetwork.getLayer(0).getParams(), sgdNetwork.getLayer(0).getParams());
+assertEquals(frozenNetwork.getLayer(1).getParams(), sgdNetwork.getLayer(1).getParams());
+assertEquals(frozenNetwork.getLayer(2).getParams(), sgdNetwork.getLayer(2).getParams());
+assertEquals(frozenNetwork.getLayer(3).getParams(), sgdNetwork.getLayer(3).getParams());

 }

@@ -360,17 +360,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

 ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
 frozenComputationGraph.init();
-INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup();
+INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).getParams().dup();

 ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
 sgdComputationGraph.init();
-INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup();
+INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).getParams().dup();

 for (int i = 0; i < 100; i++) {
 frozenComputationGraph.fit(randomData);

@@ -379,10 +379,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 sgdComputationGraph.fit(randomData);
 }

-assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
-assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params());
-assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params());
-assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params());
+assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
+assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
+assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
+assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).getParams(), sgdComputationGraph.getLayer(frozenBranchOutput).getParams());

 }
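These tests rely on FrozenLayerWithBackprop, which wraps a layer configuration so that its parameters are never updated while gradients still propagate through it to the trainable layers below. A configuration sketch; builder entry points differ between versions, so only the wrapping itself is shown:

    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;

    // A dense layer whose weights stay at their initial values during fit(),
    // while backprop still flows through it to earlier layers:
    FrozenLayerWithBackprop frozen =
            new FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(3).build());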
@@ -68,9 +68,9 @@ public class OutputLayerTest extends BaseDL4JTest {
 INDArray params = Nd4j.create(1, numParams);
 OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf,
 Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
-params = l.params();
+params = l.getModelParams();
 l.setParamsTable(params);
-assertEquals(params, l.params());
+assertEquals(params, l.getModelParams());
 }

 @Test

@@ -217,8 +217,8 @@ public class OutputLayerTest extends BaseDL4JTest {
 //However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
 //RnnOutputLayer has miniBatch examples
 //Hence: expect difference in scores by factor of timeSeriesLength
-double score = mln.score() * timeSeriesLength;
-double scoreRNN = mlnRnn.score();
+double score = mln.getScore() * timeSeriesLength;
+double scoreRNN = mlnRnn.getScore();

 assertFalse(Double.isNaN(score));
 assertFalse(Double.isNaN(scoreRNN));

@@ -234,7 +234,7 @@ public class OutputLayerTest extends BaseDL4JTest {

 RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer();
 //assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
-//Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version.
+//Input may be set by BaseLayerConfiguration methods. Thus input may end up as reshaped 2d version instead of original 3d version.
 //Not ideal, but everything else works.
 assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});

@@ -303,7 +303,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
 mln2.init();

-mln2.setParams(mln.params());
+mln2.setParams(mln.getModelParams());

 INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);

@@ -330,7 +330,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 mln2.computeGradientAndScore();

 assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
-assertEquals(mln.score(), mln2.score(), 1e-6);
+assertEquals(mln.getScore(), mln2.getScore(), 1e-6);

 TestUtils.testModelSerialization(mln);
 }

@@ -386,7 +386,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 mln2.init();


-mln2.setParams(mln.params());
+mln2.setParams(mln.getModelParams());


 INDArray in = Nd4j.rand(3, 3, 5, 5);

@@ -407,7 +407,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 mln.computeGradientAndScore();
 mln2.computeGradientAndScore();

-assertEquals(mln.score(), mln2.score(), 1e-6);
+assertEquals(mln.getScore(), mln2.getScore(), 1e-6);
 assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());

 //Also check computeScoreForExamples

@@ -479,7 +479,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 graph2.init();


-graph2.setParams(graph.params());
+graph2.setParams(graph.getModelParams());


 INDArray in = Nd4j.rand(3, 3, 5, 5);

@@ -500,7 +500,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 graph.computeGradientAndScore();
 graph2.computeGradientAndScore();

-assertEquals(graph.score(), graph2.score(), 1e-6);
+assertEquals(graph.getScore(), graph2.getScore(), 1e-6);
 assertEquals(graph.gradient().gradient(), graph2.gradient().gradient());

 //Also check computeScoreForExamples
@@ -59,13 +59,13 @@ public class SeedTest extends BaseDL4JTest {
 layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());

 layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-double score = layer.score();
-INDArray parameters = layer.params();
+double score = layer.getScore();
+INDArray parameters = layer.getParams();
 layer.setParams(parameters);
 layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());

-double score2 = layer.score();
-assertEquals(parameters, layer.params());
+double score2 = layer.getScore();
+assertEquals(parameters, layer.getParams());
 assertEquals(score, score2, 1e-4);
 }
 }
@@ -845,9 +845,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {

 public static void testHelper(TestCase tc) {

-tc.net2.params().assign(tc.net1.params());
-tc.net3.params().assign(tc.net1.params());
-tc.net4.params().assign(tc.net1.params());
+tc.net2.getModelParams().assign(tc.net1.getModelParams());
+tc.net3.getModelParams().assign(tc.net1.getModelParams());
+tc.net4.getModelParams().assign(tc.net1.getModelParams());

 //Test forward pass:
 INDArray inNCHW = tc.inNCHW;

@@ -909,9 +909,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
 tc.net3.fit(inNHWC, tc.labelsNHWC);
 tc.net4.fit(inNHWC, tc.labelsNHWC);

-assertEquals(tc.net1.params(), tc.net2.params(), tc.msg);
-assertEquals(tc.net1.params(), tc.net3.params(), tc.msg);
-assertEquals(tc.net1.params(), tc.net4.params(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);

 //Test serialization
 MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
@@ -30,7 +30,6 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;

@@ -38,7 +37,6 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.nn.weights.WeightInitNormal;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;

@@ -450,10 +448,10 @@ public class ConvolutionLayerTest extends BaseDL4JTest {

 MultiLayerNetwork net = getCNNMLNConfig(true, false);

-INDArray paramsOrig = net.params().dup();
+INDArray paramsOrig = net.getModelParams().dup();
 net.setParams(paramsOrig);

-INDArray params2 = net.params();
+INDArray params2 = net.getModelParams();

 assertEquals(paramsOrig, params2);
 }
@@ -154,7 +154,7 @@ public class TestCustomLayers extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-assertEquals(net2.params(), net.params());
+assertEquals(net2.getModelParams(), net.getModelParams());

 INDArray testFeatures = Nd4j.rand(1, 10);
 INDArray testLabels = Nd4j.zeros(1, 10);

@@ -207,7 +207,7 @@ public class TestCustomLayers extends BaseDL4JTest {
 ComputationGraph net2 = new ComputationGraph(conf2);
 net2.init();

-assertEquals(net2.params(), net.params());
+assertEquals(net2.getModelParams(), net.getModelParams());

 INDArray testFeatures = Nd4j.rand(1, 10);
 INDArray testLabels = Nd4j.zeros(1, 10);
@@ -56,7 +56,7 @@ public class CustomLayer extends FeedForwardLayer {
 boolean initializeParams, DataType networkDataType) {
 LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
 CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType);
-ret.setListeners(trainingListeners);
+ret.addTrainingListeners(trainingListeners);
 ret.setIndex(layerIndex);
 ret.setParamsViewArray(layerParamsView);
 Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
@@ -54,7 +54,7 @@ public class CustomOutputLayer extends BaseOutputLayer {
 int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
 LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
 CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType);
-ret.setListeners(trainingListeners);
+ret.addTrainingListeners(trainingListeners);
 ret.setIndex(layerIndex);
 ret.setParamsViewArray(layerParamsView);
 Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
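Both custom layers follow the same instantiate(...) contract as the framework's built-in layers: construct the implementation, attach listeners (now via addTrainingListeners), record the layer index, and hand over the layer's slice of the network's flattened parameter array before building the param table. The generic shape, with MyLayerImpl as a hypothetical implementation class:

    LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
    MyLayerImpl ret = new MyLayerImpl(lconf, networkDataType);   // MyLayerImpl: hypothetical
    ret.addTrainingListeners(trainingListeners);                 // renamed from setListeners(...)
    ret.setIndex(layerIndex);
    ret.setParamsViewArray(layerParamsView);                     // params are a view into the flat array
    Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
    // ...the param table is then handed to the layer and the layer returned, as in both classes above.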
@@ -72,7 +72,7 @@ public class DenseTest extends BaseDL4JTest {

 DataSet test = iter.next();

-assertEquals(model.params(), model2.params());
+assertEquals(model.getModelParams(), model2.getModelParams());

 Evaluation eval = new Evaluation();
 INDArray output = model.output(test.getFeatures());

@@ -99,7 +99,7 @@ public class DenseTest extends BaseDL4JTest {

 DataSet test = iter.next();

-assertEquals(model.params(), model2.params());
+assertEquals(model.getModelParams(), model2.getModelParams());

 Evaluation eval = new Evaluation();
 INDArray output = model.output(test.getFeatures());
@@ -169,7 +169,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.init();
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 int batchSize = 3;
 INDArray inEmbedding = Nd4j.create(batchSize, 1);

@@ -216,7 +216,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.init();
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 int batchSize = 3;
 INDArray inEmbedding = Nd4j.create(batchSize, 1);

@@ -262,7 +262,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.init();
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 int batchSize = 3;
 INDArray inEmbedding = Nd4j.create(batchSize, 1);

@@ -287,7 +287,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.computeGradientAndScore();
 net2.computeGradientAndScore();

-assertEquals(net2.score(), net.score(), 1e-6);
+assertEquals(net2.getScore(), net.getScore(), 1e-6);

 Map<String, INDArray> gradient = net.gradient().gradientForVariable();
 Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();

@@ -323,7 +323,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.init();
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 int batchSize = 3;
 INDArray inEmbedding = Nd4j.create(batchSize, 1);

@@ -349,7 +349,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net2.computeGradientAndScore();

 // System.out.println(net.score() + "\t" + net2.score());
-assertEquals(net2.score(), net.score(), 1e-6);
+assertEquals(net2.getScore(), net.getScore(), 1e-6);

 Map<String, INDArray> gradient = net.gradient().gradientForVariable();
 Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();

@@ -395,7 +395,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.init();
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength);
 INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength);

@@ -422,7 +422,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net2.computeGradientAndScore();

 // System.out.println(net.score() + "\t" + net2.score());
-assertEquals(net2.score(), net.score(), 1e-5);
+assertEquals(net2.getScore(), net.getScore(), 1e-5);

 Map<String, INDArray> gradient = net.gradient().gradientForVariable();
 Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();

@@ -484,7 +484,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
 INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);

@@ -523,7 +523,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net2.computeGradientAndScore();

 // System.out.println(net.score() + "\t" + net2.score());
-assertEquals(net2.score(), net.score(), 1e-5);
+assertEquals(net2.getScore(), net.getScore(), 1e-5);

 Map<String, INDArray> gradients = net.gradient().gradientForVariable();
 Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();

@@ -640,7 +640,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
 INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);

@@ -678,7 +678,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 net.computeGradientAndScore();
 net2.computeGradientAndScore();

-assertEquals(net2.score(), net.score(), 1e-5);
+assertEquals(net2.getScore(), net.getScore(), 1e-5);

 Map<String, INDArray> gradients = net.gradient().gradientForVariable();
 Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();

@@ -777,9 +777,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
 net3.init();

-INDArray p1 = net.params();
-INDArray p2 = net2.params();
-INDArray p3 = net3.params();
+INDArray p1 = net.getModelParams();
+INDArray p2 = net2.getModelParams();
+INDArray p3 = net3.getModelParams();
 boolean eq = p1.equalsWithEps(p2, 1e-4);
 String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
 assertTrue(eq, str + " p1/p2 params not equal");
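EmbeddingLayerTest keeps checking one equivalence: an EmbeddingLayer fed integer class indices must behave exactly like a DenseLayer fed the one-hot encoding of those indices, since an embedding lookup is just a matrix product with a one-hot vector. A sketch of the paired inputs the tests build, assuming batchSize and nClassesIn as above:

    INDArray inEmbedding = Nd4j.create(batchSize, 1);
    INDArray inOneHot = Nd4j.create(batchSize, nClassesIn);
    java.util.Random r = new java.util.Random(12345);
    for (int i = 0; i < batchSize; i++) {
        int classIdx = r.nextInt(nClassesIn);
        inEmbedding.putScalar(i, classIdx);               // embedding net: raw index
        inOneHot.putScalar(new int[]{i, classIdx}, 1.0);  // dense net: one-hot row
    }
    // With net2.setParams(net.getModelParams().dup()), both nets must agree on
    // getScore() and on every gradient, which is what the asserts above verify.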
@@ -514,7 +514,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {

 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
-net.setListeners(new ScoreIterationListener(100));
+net.addTrainingListeners(new ScoreIterationListener(100));

 int nEpochs = 1000;
 DataSet ds = iter.next();
@@ -79,13 +79,13 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
 if (doLearningFirst) {
 //Run a number of iterations of learning
 network.setInput(arr);
-network.setListeners(new ScoreIterationListener(1));
+network.addTrainingListeners(new ScoreIterationListener(1));
 network.computeGradientAndScore();
-double scoreBefore = network.score();
+double scoreBefore = network.getScore();
 for (int j = 0; j < 10; j++)
 network.fit(ds);
 network.computeGradientAndScore();
-double scoreAfter = network.score();
+double scoreAfter = network.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn="
 + "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid"

@@ -147,7 +147,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
 tmpFile.deleteOnExit();

 MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile);
-assertEquals(network.params(),multiLayerNetwork.params());
+assertEquals(network.getModelParams(),multiLayerNetwork.getModelParams());
 assertEquals(network.numParams(),multiLayerNetwork.numParams());

 }

@@ -187,7 +187,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
 .build();
 MultiLayerNetwork network = new MultiLayerNetwork(configuration);
 network.init();
-network.setListeners(new ScoreIterationListener(1));
+network.addTrainingListeners(new ScoreIterationListener(1));
 return network;
 }
@@ -124,7 +124,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 assertEquals(n1, n2);
 }

-net2.setParams(net1.params()); //Assuming exact same layout here...
+net2.setParams(net1.getModelParams()); //Assuming exact same layout here...

 INDArray in;
 if (rnnDataFormat == NCW){

@@ -154,7 +154,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 net2.computeGradientAndScore();

 //Ensure scores are equal:
-assertEquals(net1.score(), net2.score(), 1e-6);
+assertEquals(net1.getScore(), net2.getScore(), 1e-6);

 //Ensure gradients are equal:
 Gradient g1 = net1.gradient();

@@ -174,8 +174,8 @@ public class BidirectionalTest extends BaseDL4JTest {
 net1.fit(in, labels);
 net2.fit(in, labels);

-INDArray p1 = net1.params();
-INDArray p2 = net2.params();
+INDArray p1 = net1.getModelParams();
+INDArray p2 = net2.getModelParams();
 assertEquals(p1, p2);
 }
 }

@@ -232,7 +232,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 assertEquals(n1, n2);
 }

-net2.setParams(net1.params()); //Assuming exact same layout here...
+net2.setParams(net1.getModelParams()); //Assuming exact same layout here...

 INDArray in = Nd4j.rand(3, 10, 5);

@@ -253,7 +253,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 net2.computeGradientAndScore();

 //Ensure scores are equal:
-assertEquals(net1.score(), net2.score(), 1e-6);
+assertEquals(net1.getScore(), net2.getScore(), 1e-6);

 //Ensure gradients are equal:
 Gradient g1 = net1.gradient();

@@ -273,8 +273,8 @@ public class BidirectionalTest extends BaseDL4JTest {
 net1.fit(new DataSet(in, labels));
 net2.fit(new DataSet(in, labels));

-INDArray p1 = net1.params();
-INDArray p2 = net2.params();
+INDArray p1 = net1.getModelParams();
+INDArray p2 = net2.getModelParams();
 assertEquals(p1, p2);
 }
 }

@@ -340,7 +340,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 net1.computeGradientAndScore();
 net2.computeGradientAndScore();

-assertEquals(net1.score(), net2.score(), 1e-6);
+assertEquals(net1.getScore(), net2.getScore(), 1e-6);
 assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
 }
 }

@@ -403,7 +403,7 @@ public class BidirectionalTest extends BaseDL4JTest {
 net1.computeGradientAndScore();
 net2.computeGradientAndScore();

-assertEquals(net1.score(), net2.score(), 1e-6);
+assertEquals(net1.getScore(), net2.getScore(), 1e-6);
 assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
 }
 }

@@ -277,7 +277,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {

 final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());

-params = bidirectionalLSTM.params();
+params = bidirectionalLSTM.getModelParams();

 bidirectionalLSTM.setParamsTable(params);

@@ -285,9 +285,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {

 public static void testHelper(TestCase tc) {

-tc.net2.params().assign(tc.net1.params());
-tc.net3.params().assign(tc.net1.params());
-tc.net4.params().assign(tc.net1.params());
+tc.net2.getModelParams().assign(tc.net1.getModelParams());
+tc.net3.getModelParams().assign(tc.net1.getModelParams());
+tc.net4.getModelParams().assign(tc.net1.getModelParams());

 INDArray inNCW = tc.inNCW;
 INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();

@@ -352,9 +352,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {
 tc.net3.fit(inNWC, tc.labelsNWC);
 tc.net4.fit(inNWC, tc.labelsNWC);

-assertEquals(tc.net1.params(), tc.net2.params(), tc.msg);
-assertEquals(tc.net1.params(), tc.net3.params(), tc.msg);
-assertEquals(tc.net1.params(), tc.net4.params(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
+assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);

 //Test serialization
 MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);

@@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.dropout.TestDropout;
 import org.deeplearning4j.nn.conf.layers.GravesLSTM;

@@ -173,8 +172,8 @@ public class TestRnnLayers extends BaseDL4JTest {
 MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2);
 netD2.init();

-assertEquals(net.params(), netD.params(), s);
-assertEquals(net.params(), netD2.params(), s);
+assertEquals(net.getModelParams(), netD.getModelParams(), s);
+assertEquals(net.getModelParams(), netD2.getModelParams(), s);

 INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10);

@@ -193,7 +192,7 @@ public class TestRnnLayers extends BaseDL4JTest {
 INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
 net.fit(f.dup(), l);
 netD.fit(f.dup(), l);
-assertNotEquals(net.params(), netD.params(), s);
+assertNotEquals(net.getModelParams(), netD.getModelParams(), s);

 netD2.fit(f.dup(), l);
 netD2.fit(f.dup(), l);

@@ -115,7 +115,7 @@ public class TestTimeDistributed extends BaseDL4JTest {
 net1.fit(ds);
 net2.fit(ds);

-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());

 MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
 out2 = net2.output(in);

@@ -124,10 +124,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-net.params().assign(net2.params());
+net.getModelParams().assign(net2.getModelParams());

 //Check params:
-assertEquals(net2.params(), net.params());
+assertEquals(net2.getModelParams(), net.getModelParams());
 Map<String, INDArray> params1 = net.getParamTable();
 Map<String, INDArray> params2 = net2.getParamTable();
 assertEquals(params2, params1);

@@ -209,10 +209,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-assertEquals(net2.params(), net.params());
+assertEquals(net2.getModelParams(), net.getModelParams());

 //Check params:
-assertEquals(net2.params(), net.params());
+assertEquals(net2.getModelParams(), net.getModelParams());
 Map<String, INDArray> params1 = net.getParamTable();
 Map<String, INDArray> params2 = net2.getParamTable();
 assertEquals(params2, params1);

@@ -287,10 +287,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
 MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
 netStandard.init();

-netSD.params().assign(netStandard.params());
+netSD.getModelParams().assign(netStandard.getModelParams());

 //Check params:
-assertEquals(netStandard.params(), netSD.params());
+assertEquals(netStandard.getModelParams(), netSD.getModelParams());
 assertEquals(netStandard.getParamTable(), netSD.getParamTable());

 INDArray in = Nd4j.rand(minibatch, nIn);

@@ -379,10 +379,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
 MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
 netStandard.init();

-netSD.params().assign(netStandard.params());
+netSD.getModelParams().assign(netStandard.getModelParams());

 //Check params:
-assertEquals(netStandard.params(), netSD.params());
+assertEquals(netStandard.getModelParams(), netSD.getModelParams());
 assertEquals(netStandard.getParamTable(), netSD.getParamTable());

 DataSetIterator iter = new IrisDataSetIterator(150, 150);

@@ -398,7 +398,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
 netStandard.fit(ds);
 String s = String.valueOf(i);
 assertEquals( netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s);
-assertEquals( netStandard.params(), netSD.params(), s);
+assertEquals( netStandard.getModelParams(), netSD.getModelParams(), s);
 assertEquals( netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s);
 }

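These SameDiff-versus-standard comparisons all sync parameters through the flattened view before comparing anything: getModelParams() returns a view backed by the network's parameter array, so an in-place assign copies one network's weights into the other. A sketch under that assumption (net names follow the hunks above; the surrounding network setup is elided in the diff):

    // copy netStandard's weights into netSD through the flattened parameter view
    netSD.getModelParams().assign(netStandard.getModelParams());

    // both networks now hold identical weights, as the hunks assert
    assertEquals(netStandard.getModelParams(), netSD.getModelParams());
    assertEquals(netStandard.getParamTable(), netSD.getParamTable());
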
@@ -100,10 +100,10 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
 ComputationGraph netStandard = new ComputationGraph(conf2);
 netStandard.init();

-netSD.params().assign(netStandard.params());
+netSD.getModelParams().assign(netStandard.getModelParams());

 //Check params:
-assertEquals(netStandard.params(), netSD.params());
+assertEquals(netStandard.getModelParams(), netSD.getModelParams());
 assertEquals(netStandard.getParamTable(), netSD.getParamTable());

 INDArray in = Nd4j.rand(minibatch, nIn);

@@ -160,7 +160,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
 netStandard.fit(ds);

 assertEquals(netStandard.getParamTable(), netSD.getParamTable());
-assertEquals(netStandard.params(), netSD.params());
+assertEquals(netStandard.getModelParams(), netSD.getModelParams());
 assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients());
 }

@@ -98,7 +98,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
 ComputationGraph std = new ComputationGraph(confStd);
 std.init();

-lambda.setParams(std.params());
+lambda.setParams(std.getModelParams());

 INDArray in = Nd4j.rand(3, 5);
 INDArray labels = TestUtils.randomOneHot(3, 5);

@@ -119,7 +119,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
 std.fit(ds);

 String s = String.valueOf(i);
-assertEquals(std.params(), lambda.params(), s);
+assertEquals(std.getModelParams(), lambda.getModelParams(), s);
 assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
 }

@@ -182,7 +182,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
 ComputationGraph std = new ComputationGraph(confStd);
 std.init();

-lambda.setParams(std.params());
+lambda.setParams(std.getModelParams());

 INDArray in1 = Nd4j.rand(3, 5);
 INDArray in2 = Nd4j.rand(3, 5);

@@ -204,7 +204,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
 std.fit(mds);

 String s = String.valueOf(i);
-assertEquals(std.params(), lambda.params(), s);
+assertEquals(std.getModelParams(), lambda.getModelParams(), s);
 assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
 }

@@ -85,7 +85,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
 netSD.fit(ds);
 netStd.fit(ds);

-assertEquals(netStd.params(), netSD.params());
+assertEquals(netStd.getModelParams(), netSD.getModelParams());
 assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients());
 }

@@ -131,7 +131,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
 MultiLayerNetwork netStd = new MultiLayerNetwork(confStd);
 netStd.init();

-netSD.params().assign(netStd.params());
+netSD.getModelParams().assign(netStd.getModelParams());

 assertEquals(netStd.getParamTable(), netSD.getParamTable());

@@ -165,7 +165,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
 netSD.fit(ds);
 netStd.fit(ds);
 String s = String.valueOf(i);
-assertEquals( netStd.params(), netSD.params(), s);
+assertEquals( netStd.getModelParams(), netSD.getModelParams(), s);
 assertEquals( netStd.getFlattenedGradients(), netSD.getFlattenedGradients(),s );
 }

@@ -77,7 +77,7 @@ public class TestVAE extends BaseDL4JTest {
 net.init();

 System.out.println("Exp num params: " + expNumParams);
-assertEquals(expNumParams, net.getLayer(0).params().length());
+assertEquals(expNumParams, net.getLayer(0).getParams().length());
 Map<String, INDArray> paramTable = net.getLayer(0).getParamTable();
 int count = 0;
 for (INDArray arr : paramTable.values()) {

@@ -79,7 +79,7 @@ public class CloseNetworkTests extends BaseDL4JTest {

 net.close();

-assertTrue(net.params().wasClosed());
+assertTrue(net.getModelParams().wasClosed());
 if(train) {
 assertTrue(net.getGradientsViewArray().wasClosed());
 Updater u = net.getUpdater(false);

@@ -127,7 +127,7 @@ public class CloseNetworkTests extends BaseDL4JTest {

 net.close();

-assertTrue(net.params().wasClosed());
+assertTrue(net.getModelParams().wasClosed());
 if(train) {
 assertTrue(net.getGradientsViewArray().wasClosed());
 Updater u = net.getUpdater(false);

@@ -57,7 +57,7 @@ public class LargeNetTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-INDArray params = net.params();
+INDArray params = net.getModelParams();
 long paramsLength = params.length();
 long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
 assertEquals(expParamsLength, paramsLength);

@@ -91,7 +91,7 @@ public class LargeNetTest extends BaseDL4JTest {
 ComputationGraph net = new ComputationGraph(conf);
 net.init();

-INDArray params = net.params();
+INDArray params = net.getModelParams();
 long paramsLength = params.length();
 long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
 assertEquals(expParamsLength, paramsLength);

@@ -76,7 +76,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.init();
 net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf2.setIterationCount(conf.getIterationCount());
-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0);
 net.setLearningRate(0, 0.5); //Set LR for layer 0 to 0.5

@@ -96,7 +96,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.fit(in, l);
 }

-assertEquals(net.params(), net2.params());
+assertEquals(net.getModelParams(), net2.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());

 INDArray in1 = Nd4j.rand(10, 10);

@@ -110,7 +110,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.setLabels(l1);
 net2.computeGradientAndScore();

-assertEquals(net.score(), net2.score(), 1e-8);
+assertEquals(net.getScore(), net2.getScore(), 1e-8);


 //Now: Set *all* LRs to say 0.3...

@@ -126,7 +126,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net3.init();
 net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf3.setIterationCount(conf.getIterationCount());
-net3.setParams(net.params().dup());
+net3.setParams(net.getModelParams().dup());

 net.setLearningRate(0.3);

@@ -139,7 +139,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net3.fit(in, l);
 }

-assertEquals(net.params(), net3.params());
+assertEquals(net.getModelParams(), net3.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
 }

@@ -206,7 +206,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.init();
 net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf2.setIterationCount(conf.getIterationCount());
-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5

@@ -224,7 +224,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.fit(in, l);
 }

-assertEquals(net.params(), net2.params());
+assertEquals(net.getModelParams(), net2.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
 }

@@ -270,7 +270,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.init();
 net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf2.setIterationCount(conf.getIterationCount());
-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0);
 net.setLearningRate("0", 0.5); //Set LR for layer 0 to 0.5

@@ -290,7 +290,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.fit(new DataSet(in, l));
 }

-assertEquals(net.params(), net2.params());
+assertEquals(net.getModelParams(), net2.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());

 INDArray in1 = Nd4j.rand(10, 10);

@@ -304,7 +304,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.setLabels(l1);
 net2.computeGradientAndScore();

-assertEquals(net.score(), net2.score(), 1e-8);
+assertEquals(net.getScore(), net2.getScore(), 1e-8);


 //Now: Set *all* LRs to say 0.3...

@@ -320,7 +320,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net3.init();
 net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf3.setIterationCount(conf.getIterationCount());
-net3.setParams(net.params().dup());
+net3.setParams(net.getModelParams().dup());

 net.setLearningRate(0.3);

@@ -333,7 +333,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net3.fit(new DataSet(in, l));
 }

-assertEquals(net.params(), net3.params());
+assertEquals(net.getModelParams(), net3.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
 }

@@ -375,7 +375,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.init();
 net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
 conf2.setIterationCount(conf.getIterationCount());
-net2.setParams(net.params().dup());
+net2.setParams(net.getModelParams().dup());

 net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5

@@ -393,7 +393,7 @@ public class TestLrChanges extends BaseDL4JTest {
 net2.fit(new DataSet(in, l));
 }

-assertEquals(net.params(), net2.params());
+assertEquals(net.getModelParams(), net2.getModelParams());
 assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
 }

@@ -77,14 +77,14 @@ public class TestNetConversion extends BaseDL4JTest {
 n.computeGradientAndScore();
 cg.computeGradientAndScore();

-assertEquals(n.score(), cg.score(), 1e-6);
+assertEquals(n.getScore(), cg.getScore(), 1e-6);

 assertEquals(n.gradient().gradient(), cg.gradient().gradient());

 n.fit(in, labels);
 cg.fit(new INDArray[]{in}, new INDArray[]{labels});

-assertEquals(n.params(), cg.params());
+assertEquals(n.getModelParams(), cg.getModelParams());
 }
 }

@@ -476,7 +476,7 @@ public class WorkspaceTests extends BaseDL4JTest {

 final ComputationGraph computationGraph = new ComputationGraph(config);
 computationGraph.init();
-computationGraph.setListeners(new ScoreIterationListener(3));
+computationGraph.addTrainingListeners(new ScoreIterationListener(3));

 WSTestDataSetIterator iterator = new WSTestDataSetIterator();
 computationGraph.fit(iterator);

@@ -54,7 +54,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
 public void testMLPTrivial() {
 //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1.
 MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID));
-network.setListeners(new ScoreIterationListener(1));
+network.addTrainingListeners(new ScoreIterationListener(1));
 network.init();

 DataSetIterator iter = new IrisDataSetIterator(1, 10);

@@ -64,7 +64,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.AutoEncoder;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.BatchNormalization;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;

@@ -184,13 +184,13 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork network3 = new MultiLayerNetwork(conf);
 network3.init();

-INDArray params = network3.params();
+INDArray params = network3.getModelParams();
 INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup();
 INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup();
 network3.setParameters(params);
 assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY));
 assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY));
-INDArray params4 = network3.params();
+INDArray params4 = network3.getModelParams();
 assertEquals(params, params4);
 }

@@ -211,7 +211,7 @@ public class MultiLayerTest extends BaseDL4JTest {

 MultiLayerNetwork network = new MultiLayerNetwork(conf);
 network.init();
-network.setListeners(new ScoreIterationListener(1));
+network.addTrainingListeners(new ScoreIterationListener(1));

 DataSetIterator iter = new IrisDataSetIterator(150, 150);

@@ -242,7 +242,7 @@ public class MultiLayerTest extends BaseDL4JTest {

 MultiLayerNetwork network = new MultiLayerNetwork(conf);
 network.init();
-network.setListeners(new ScoreIterationListener(1));
+network.addTrainingListeners(new ScoreIterationListener(1));

 DataSetIterator iter = new IrisDataSetIterator(150, 150);

@@ -330,7 +330,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork model = new MultiLayerNetwork(conf);
 model.init();

-model.addListeners(new ScoreIterationListener(listenerFreq));
+model.addTrainingListeners(new ScoreIterationListener(listenerFreq));

 log.info("Train model....");
 int cnt = 0;

@@ -503,7 +503,7 @@ public class MultiLayerTest extends BaseDL4JTest {

 assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName());
 assertEquals(layerNameList, net.getLayerNames());
-BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).getLayerConfiguration();
+BaseLayerConfiguration b = (BaseLayerConfiguration) net.getLayer(layerNameList.get(2)).getLayerConfiguration();
 assertEquals("softmax", b.getActivationFn().toString());
 }

@@ -535,7 +535,7 @@ public class MultiLayerTest extends BaseDL4JTest {

 MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg);
 netNoReg.init();
-netNoReg.setParameters(net.params().dup());
+netNoReg.setParameters(net.getModelParams().dup());

 //Score single example, and compare to scoreExamples:
 INDArray input = Nd4j.rand(3, nIn);

@@ -703,7 +703,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
 net.fit(iter.next());
-// TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars
+// TODO validate actual layer gradientView - issue getting var out of BaseLayerConfiguration w/o adding MLN getter that gets confused with local gradient vars
 Gradient actualGradient = net.gradient;
 assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W"));

@@ -716,13 +716,13 @@ public class MultiLayerTest extends BaseDL4JTest {
 net.setParam("0_b", Nd4j.ones(1, 5));
 net.setParam("1_W", Nd4j.ones(5, 3));
 net.setParam("1_b", Nd4j.ones(1, 3));
-INDArray actualParams = net.params();
+INDArray actualParams = net.getModelParams();

 // Confirm params
 assertEquals(expectedGradient.gradient(), actualParams);

 net.update(expectedGradient);
-actualParams = net.params();
+actualParams = net.getModelParams();
 assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
 }

@@ -762,7 +762,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork aePre = getAeModel(true, nIn, nOut);
 int actualNP = (int) aePre.numParams();
 assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
-INDArray params = aePre.params();
+INDArray params = aePre.getModelParams();
 assertEquals(params.length(), actualNP); // check num params
 Map<String, INDArray> paramTable = aePre.getParamTable();
 assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer

@@ -774,7 +774,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut);
 actualNP = (int) aeNoPre.numParams();
 assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
-params = aeNoPre.params();
+params = aeNoPre.getModelParams();
 assertEquals(params.length(), actualNP);
 paramTable = aePre.getParamTable();
 assertTrue(paramTable.containsKey("0_vb"));

@@ -865,14 +865,14 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-BaseLayer bl0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration();
+BaseLayerConfiguration bl0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
 assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6);
 assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6);

 INDArray features = Nd4j.rand(10, 10);
 INDArray labels = Nd4j.rand(10, 10);

-net2.setParams(net1.params().dup());
+net2.setParams(net1.getModelParams().dup());

 net1.setInput(features);
 net1.setLabels(labels);

@@ -888,15 +888,15 @@ public class MultiLayerTest extends BaseDL4JTest {
 r = net2.calcRegularizationScore(true);
 assertEquals(0.0, r, 0.0);

-double s1 = net1.score();
-double s2 = net2.score();
+double s1 = net1.getScore();
+double s2 = net2.getScore();
 assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score

 for (int i = 0; i < 10; i++) {
 net1.fit(features, labels);
 }

-net2.setParams(net1.params().dup());
+net2.setParams(net1.getModelParams().dup());
 net1.computeGradientAndScore();
 net2.computeGradientAndScore();

@@ -906,8 +906,8 @@ public class MultiLayerTest extends BaseDL4JTest {
 r = net2.calcRegularizationScore(true);
 assertTrue(r > 0.0);

-s1 = net1.score();
-s2 = net2.score();
+s1 = net1.getScore();
+s2 = net2.getScore();

 assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2

@@ -1022,11 +1022,11 @@ public class MultiLayerTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();

-assertNotEquals(net1.params(), net2.params());
+assertNotEquals(net1.getModelParams(), net2.getModelParams());
 assertNotEquals(net1.getParamTable(), net2.getParamTable());

 net1.setParamTable(net2.getParamTable());
-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
 assertEquals(net1.getParamTable(), net2.getParamTable());
 }

@@ -1412,7 +1412,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 exp.add(MultiLayerNetwork.class);

 CheckModelsListener listener = new CheckModelsListener();
-net.setListeners(listener);
+net.addTrainingListeners(listener);

 INDArray f = Nd4j.create(1, 10);
 INDArray l = Nd4j.create(1, 10);

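The MultiLayerTest hunks also carry a class rename: the layer-configuration base type BaseLayer is now BaseLayerConfiguration, so casts on getLayerConfiguration() results change with it. A small sketch, assuming the accessor names shown in the hunks above (the helper itself is illustrative):

    import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    static String activationName(MultiLayerNetwork net, String layerName) {
        // was: BaseLayer b = (BaseLayer) net.getLayer(layerName).getLayerConfiguration();
        BaseLayerConfiguration b =
                (BaseLayerConfiguration) net.getLayer(layerName).getLayerConfiguration();
        return b.getActivationFn().toString();
    }
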
@@ -753,9 +753,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest {

 DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);

-INDArray initialParams = mln.params().dup();
+INDArray initialParams = mln.getModelParams().dup();
 mln.fit(ds);
-INDArray afterParams = mln.params();
+INDArray afterParams = mln.getModelParams();
 assertNotEquals(initialParams, afterParams);
 }

@@ -172,7 +172,7 @@ public class TestMasking extends BaseDL4JTest {
 net.setLabels(labels);

 net.computeGradientAndScore();
-double score1 = net.score();
+double score1 = net.getScore();
 INDArray grad1 = net.gradient().gradient();

 //Now: change the label values for the masked steps. The

@@ -187,7 +187,7 @@ public class TestMasking extends BaseDL4JTest {

 assertNotEquals(labels, newLabels);

-double score2 = net.score();
+double score2 = net.getScore();
 INDArray grad2 = net.gradient().gradient();

 assertEquals(score1, score2, 1e-6);

@@ -214,7 +214,7 @@ public class TestMasking extends BaseDL4JTest {
 graph.setLabels(labels);
 graph.computeGradientAndScore();

-double gScore1 = graph.score();
+double gScore1 = graph.getScore();
 INDArray gGrad1 = graph.gradient().gradient();

 graph.setLayerMaskArrays(null, new INDArray[] {labelMask});

@@ -222,7 +222,7 @@ public class TestMasking extends BaseDL4JTest {
 graph.setLabels(newLabels);
 graph.computeGradientAndScore();

-double gScore2 = graph.score();
+double gScore2 = graph.getScore();
 INDArray gGrad2 = graph.gradient().gradient();

 assertEquals(gScore1, gScore2, 1e-6);

@@ -53,12 +53,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-INDArray initParams = net.params().dup();
+INDArray initParams = net.getModelParams().dup();
 Map<String, INDArray> initParams2 = net.getParamTable();

-net.setParams(net.params());
+net.setParams(net.getModelParams());

-INDArray initParamsAfter = net.params();
+INDArray initParamsAfter = net.getModelParams();
 Map<String, INDArray> initParams2After = net.getParamTable();

 for (String s : initParams2.keySet()) {

@@ -71,7 +71,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
 INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
 net.setParams(randomParams.dup());

-assertEquals(net.params(), randomParams);
+assertEquals(net.getModelParams(), randomParams);
 }

 @Test

@@ -90,12 +90,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

-INDArray initParams = net.params().dup();
+INDArray initParams = net.getModelParams().dup();
 Map<String, INDArray> initParams2 = net.getParamTable();

-net.setParams(net.params());
+net.setParams(net.getModelParams());

-INDArray initParamsAfter = net.params();
+INDArray initParamsAfter = net.getModelParams();
 Map<String, INDArray> initParams2After = net.getParamTable();

 for (String s : initParams2.keySet()) {

@@ -108,7 +108,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
 INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
 net.setParams(randomParams.dup());

-assertEquals(net.params(), randomParams);
+assertEquals(net.getModelParams(), randomParams);
 }

 @Test

@@ -128,7 +128,7 @@ public class TestSetGetParameters extends BaseDL4JTest {

 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
-INDArray params = net.params();
+INDArray params = net.getModelParams();


 MultiLayerNetwork net2 = new MultiLayerNetwork(conf);

@@ -137,11 +137,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
 MultiLayerNetwork net3 = new MultiLayerNetwork(conf);
 net3.init(params, false);

-assertEquals(params, net2.params());
-assertEquals(params, net3.params());
+assertEquals(params, net2.getModelParams());
+assertEquals(params, net3.getModelParams());

-assertNotSame(params, net2.params()); //Different objects due to clone
-assertSame(params, net3.params()); //Same object due to clone
+assertNotSame(params, net2.getModelParams()); //Different objects due to clone
+assertSame(params, net3.getModelParams()); //Same object due to clone


 Map<String, INDArray> paramsMap = net.getParamTable();

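The last TestSetGetParameters hunk pins down a copy-versus-alias distinction that survives the rename: initializing with a cloned parameter array yields equal contents in a different object, while init(params, false) adopts the given array directly. A sketch of the contract those assertions encode, assuming net2 was initialized via the cloning path elided from the hunk (init(params, true) here is an assumption, not shown in the diff):

    INDArray params = net.getModelParams();

    MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
    net2.init(params, true);      // assumed: clone -> equal contents, different object
    MultiLayerNetwork net3 = new MultiLayerNetwork(conf);
    net3.init(params, false);     // no clone -> the very same array object

    assertEquals(params, net2.getModelParams());
    assertNotSame(params, net2.getModelParams());
    assertSame(params, net3.getModelParams());
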
@@ -103,14 +103,14 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 net.setInput(in1);
 net.setLabels(labels1);
 net.computeGradientAndScore();
-double score1 = net.score();
+double score1 = net.getScore();
 Gradient g1 = net.gradient();

 net.setInput(in2);
 net.setLabels(labels2);
 net.setLayerMaskArrays(null, labelMask);
 net.computeGradientAndScore();
-double score2 = net.score();
+double score2 = net.getScore();
 Gradient g2 = net.gradient();

 //Scores and gradients should be identical for two cases (given mask array)

@@ -134,7 +134,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 }
 net.setLabels(labels2);
 net.computeGradientAndScore();
-double score2a = net.score();
+double score2a = net.getScore();
 Gradient g2a = net.gradient();
 assertEquals(score2, score2a, 1e-6);
 for (String s : g2map.keySet()) {

@@ -196,7 +196,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 net.setInput(in1);
 net.setLabels(labels1);
 net.computeGradientAndScore();
-double score1 = net.score();
+double score1 = net.getScore();
 Gradient g1 = net.gradient();
 Map<String, INDArray> map1 = g1.gradientForVariable();
 for (String s : map1.keySet()) {

@@ -207,7 +207,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 net.setLabels(labels2);
 net.setLayerMaskArrays(inputMask, null);
 net.computeGradientAndScore();
-double score2 = net.score();
+double score2 = net.getScore();
 Gradient g2 = net.gradient();

 net.setInput(in2);

@@ -240,7 +240,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 net.setInput(in2);
 net.setLayerMaskArrays(inputMask, null);
 net.computeGradientAndScore();
-double score2a = net.score();
+double score2a = net.getScore();
 Gradient g2a = net.gradient();
 assertEquals(score2, score2a, 1e-12);
 for (String s : g2.gradientForVariable().keySet()) {

@@ -327,7 +327,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
 mln.setLabels(labels);

 mln.computeGradientAndScore();
-double score = mln.score();
+double score = mln.getScore();

 assertEquals(expScore, score, 0.1, msg);
 }

@@ -77,7 +77,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 MultiLayerNetwork net2GradUpd = new MultiLayerNetwork(conf.clone());
 net2GradUpd.init();

-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());

 INDArray f = Nd4j.rand(minibatch, nIn);
 INDArray l = Nd4j.create(minibatch, nOut);
@@ -109,17 +109,17 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {

 //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
 // on the original network
-net2GradUpd.params().subi(g.gradient());
+net2GradUpd.getModelParams().subi(g.gradient());

 net1GradCalc.fit(f, l);
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());


 //=============================
 if (!(u instanceof Sgd)) {
 net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
 }
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
 net2GradUpd.getUpdater().getStateViewArray());

@@ -130,7 +130,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 for (int i = 0; i < 100; i++) {
 net1GradCalc.fit(f, l);
 net2GradUpd.fit(f, l);
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 }
 }
 }
@@ -169,7 +169,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 ComputationGraph net2GradUpd = new ComputationGraph(conf.clone());
 net2GradUpd.init();

-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());

 INDArray f = Nd4j.rand(minibatch, nIn);
 INDArray l = Nd4j.create(minibatch, nOut);
@@ -201,16 +201,16 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {

 //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
 // on the original network
-net2GradUpd.params().subi(g.gradient());
+net2GradUpd.getModelParams().subi(g.gradient());

 net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());

 //=============================
 if (!(u instanceof Sgd)) {
 net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
 }
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
 net2GradUpd.getUpdater().getStateViewArray());

@@ -222,7 +222,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 for (int i = 0; i < 100; i++) {
 net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
 net2GradUpd.fit(new INDArray[] {f}, new INDArray[] {l});
-assertEquals(net1GradCalc.params(), net2GradUpd.params());
+assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 }
 }
 }
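In TestMultiModelGradientApplication the rename is Model.params() to Model.getModelParams(); the method still returns the flattened parameter view, so in-place ops such as subi continue to work. A sketch of the pattern the test exercises (net1 and net2 are assumed to be identically configured, identically initialized networks, and g a Gradient obtained from computeGradientAndScore(), as in the hunks above):

    // Apply a gradient manually to one net, fit the other, then compare the views.
    net2.getModelParams().subi(g.gradient());                    // was: net2.params().subi(...)
    net1.fit(f, l);
    assertEquals(net1.getModelParams(), net2.getModelParams()); // parameters should match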
@@ -25,7 +25,6 @@ import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
 import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
@@ -94,7 +93,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {

 ComputationGraph modelToFineTune = new ComputationGraph(expectedConf);
 modelToFineTune.init();
-modelToFineTune.setParams(expectedModel.params());
+modelToFineTune.setParams(expectedModel.getModelParams());
 //model after applying changes with transfer learning
 ComputationGraph modelNow =
 new TransferLearning.GraphBuilder(modelToFineTune)
@@ -108,8 +107,8 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 //Check params after fit
 modelNow.fit(randomData);
 expectedModel.fit(randomData);
-assertEquals(modelNow.score(), expectedModel.score(), 1e-8);
-assertEquals(modelNow.params(), expectedModel.params());
+assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-8);
+assertEquals(modelNow.getModelParams(), expectedModel.getModelParams());
 }

 @Test
@@ -139,9 +138,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 //.setOutputs("layer3")
 .build();

-BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").getLayerConfiguration());
-BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").getLayerConfiguration());
-BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").getLayerConfiguration());
+BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
+BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
+BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
 assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
 assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
 assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
@@ -161,22 +160,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 modelExpectedArch.init();

 //modelNow should have the same architecture as modelExpectedArch
-assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(),
-modelNow.getLayer("layer0").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(),
-modelNow.getLayer("layer1").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(),
-modelNow.getLayer("layer2").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(),
-modelNow.getLayer("layer3").params().shape());
+assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
+modelNow.getLayer("layer0").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
+modelNow.getLayer("layer1").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
+modelNow.getLayer("layer2").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
+modelNow.getLayer("layer3").getParams().shape());

-modelNow.setParams(modelExpectedArch.params());
+modelNow.setParams(modelExpectedArch.getModelParams());
 //fit should give the same results
 modelExpectedArch.fit(randomData);
 modelNow.fit(randomData);
-assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8);
-assertEquals(modelExpectedArch.params(), modelNow.params());
+assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
+assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
 }

 @Test
@@ -227,22 +226,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 modelExpectedArch.init();

 //modelNow should have the same architecture as modelExpectedArch
-assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(),
-modelNow.getLayer("layer0").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(),
-modelNow.getLayer("layer1").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(),
-modelNow.getLayer("layer2").params().shape());
-assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(),
-modelNow.getLayer("layer3").params().shape());
+assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
+modelNow.getLayer("layer0").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
+modelNow.getLayer("layer1").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
+modelNow.getLayer("layer2").getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
+modelNow.getLayer("layer3").getParams().shape());

-modelNow.setParams(modelExpectedArch.params());
+modelNow.setParams(modelExpectedArch.getModelParams());
 //fit should give the same results
 modelExpectedArch.fit(randomData);
 modelNow.fit(randomData);
-assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8);
-assertEquals(modelExpectedArch.params(), modelNow.params());
+assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
+assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
 }

 @Test
@@ -385,14 +384,14 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {

 assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());

-modelNow.setParams(modelExpectedArch.params());
+modelNow.setParams(modelExpectedArch.getModelParams());
 int i = 0;
 while (i < 5) {
 modelExpectedArch.fit(randomData);
 modelNow.fit(randomData);
 i++;
 }
-assertEquals(modelExpectedArch.params(), modelNow.params());
+assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());

 }
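TransferLearningCompGraphTest combines three of the renames: getScore() for the loss value, getModelParams() for the network-level parameter view, and getParams() for a single layer's parameters. A sketch of the shape-comparison idiom used above (expected and actual are assumed ComputationGraph instances sharing a layer named "layer0"):

    // Network-level view vs. per-layer view, both via the new getters.
    assertArrayEquals(expected.getModelParams().shape(), actual.getModelParams().shape());
    assertArrayEquals(expected.getLayer("layer0").getParams().shape(),
            actual.getLayer("layer0").getParams().shape());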
@@ -26,10 +26,9 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.graph.MergeVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -99,7 +98,7 @@ public class TransferLearningComplex extends BaseDL4JTest {
 }

 //Also check config:
-BaseLayer bl = ((BaseLayer) l.getLayerConfiguration());
+BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
 assertEquals(new Adam(2e-2), bl.getIUpdater());
 assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
 }
@@ -154,8 +153,8 @@ public class TransferLearningComplex extends BaseDL4JTest {
 .setOutputs("outRight").build();
 ComputationGraph modelOther = new ComputationGraph(otherConf);
 modelOther.init();
-modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params());
-modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params());
+modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").getParams());
+modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").getParams());

 modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
 ComputationGraph modelNow =
@@ -179,11 +178,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
 assertEquals(otherRandData.getFeatures(0),
 modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));

-assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
-assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
+assertEquals(modelOther.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());
+assertEquals(modelOther.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());

-assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
-assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
+assertEquals(modelOther.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());
+assertEquals(modelOther.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
 n++;
 }

@@ -237,11 +236,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
 assertEquals(otherRandData.getFeatures(0),
 modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));

-assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
+assertEquals(modelToTune.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());

-assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
+assertEquals(modelToTune.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());

-assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());
+assertEquals(modelToTune.getLayer("outCentre").getParams(), modelNow.getLayer("outCentre").getParams());
 n++;
 }
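TransferLearningComplex also picks up the configuration-side rename: the layer-config base class org.deeplearning4j.nn.conf.layers.BaseLayer is now BaseLayerConfiguration (the runtime layer class org.deeplearning4j.nn.layers.BaseLayer, seen in the TestUpdaters hunks below, keeps its name). A sketch of the updated cast (the model variable and layer name are illustrative):

    // Read per-layer config values through the renamed class; the accessors on the
    // configuration object itself (getIUpdater(), getActivationFn(), ...) are unchanged.
    BaseLayerConfiguration bl =
            (BaseLayerConfiguration) model.getLayer("denseRight0").getLayerConfiguration();
    assertEquals(new Adam(2e-2), bl.getIUpdater());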
@@ -178,25 +178,25 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
 TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
 MultiDataSet featurizedDataSet = helper.featurize(origData);

-assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
+assertEquals(modelIdentical.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());
 modelIdentical.fit(origData);
 helper.fitFeaturized(featurizedDataSet);

-assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params());
-assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params());
-assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params());
-assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params());
-assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params());
+assertEquals(modelIdentical.getLayer("denseCentre0").getParams(), modelToTune.getLayer("denseCentre0").getParams());
+assertEquals(modelIdentical.getLayer("denseCentre1").getParams(), modelToTune.getLayer("denseCentre1").getParams());
+assertEquals(modelIdentical.getLayer("denseCentre2").getParams(), modelToTune.getLayer("denseCentre2").getParams());
+assertEquals(modelIdentical.getLayer("denseCentre3").getParams(), modelToTune.getLayer("denseCentre3").getParams());
+assertEquals(modelIdentical.getLayer("outCentre").getParams(), modelToTune.getLayer("outCentre").getParams());
 assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(),
 modelToTune.getLayer("denseRight").getNetConfiguration().toJson());
-assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params());
+assertEquals(modelIdentical.getLayer("denseRight").getParams(), modelToTune.getLayer("denseRight").getParams());
 assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(),
 modelToTune.getLayer("denseRight0").getNetConfiguration().toJson());
 //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params());
-assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params());
-assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
-assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params());
-assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params());
+assertEquals(modelIdentical.getLayer("denseRight1").getParams(), modelToTune.getLayer("denseRight1").getParams());
+assertEquals(modelIdentical.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
+assertEquals(modelIdentical.getLayer("denseLeft0").getParams(), modelToTune.getLayer("denseLeft0").getParams());
+assertEquals(modelIdentical.getLayer("outLeft").getParams(), modelToTune.getLayer("outLeft").getParams());

 // log.info(modelIdentical.summary());
 // log.info(helper.unfrozenGraph().summary());
@@ -230,7 +230,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
 TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1);

 INDArray paramsLastTwoLayers =
-Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
+Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
 MultiLayerNetwork notFrozen = new MultiLayerNetwork(
 (NeuralNetConfiguration) overallConf.clone().list()
 .layer(0, new Builder().nIn(2).nOut(3).build())
@@ -248,9 +248,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
 modelNow.fit(randomData);
 }

-INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
-notFrozen.params());
-INDArray act = modelNow.params();
+INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
+notFrozen.getModelParams());
+INDArray act = modelNow.getModelParams();
 assertEquals(expected, act);
 }
 }
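TransferLearningHelperTest mixes the two parameter granularities in one expression: layer-level getParams() for the frozen layers and network-level getModelParams() for the unfrozen subnet. A sketch of the expected-parameter assembly (variable names follow the test):

    INDArray expected = Nd4j.hstack(
            modelToFineTune.getLayer(0).getParams(),   // frozen layer: layer-level accessor
            notFrozen.getModelParams());               // unfrozen subnet: network-level accessor
    assertEquals(expected, modelNow.getModelParams());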
@@ -91,7 +91,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 .build();

 for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
-BaseLayer bl = ((BaseLayer) l.getLayerConfiguration());
+BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
 assertEquals(new RmsProp(0.5), bl.getIUpdater());
 }

@@ -107,9 +107,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 .build())
 .build());
 expectedModel.init();
-expectedModel.setParams(modelToFineTune.params().dup());
+expectedModel.setParams(modelToFineTune.getModelParams().dup());

-assertEquals(expectedModel.params(), modelNow.params());
+assertEquals(expectedModel.getModelParams(), modelNow.getModelParams());

 //Check json
 NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration();
@@ -119,9 +119,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 modelNow.fit(randomData);
 expectedModel.fit(randomData);

-assertEquals(modelNow.score(), expectedModel.score(), 1e-6);
-INDArray pExp = expectedModel.params();
-INDArray pNow = modelNow.params();
+assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-6);
+INDArray pExp = expectedModel.getModelParams();
+INDArray pNow = modelNow.getModelParams();
 assertEquals(pExp, pNow);
 }

@@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 //Will fail - expected because of dist and weight init changes
 //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson());

-BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer());
-BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer());
-BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer());
+BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
+BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
+BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
 assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
 try {
 assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
@@ -173,18 +173,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 assertEquals(bl3.getWeightInitFn(), new WeightInitXavier());

 //modelNow should have the same architecture as modelExpectedArch
-assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
+assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());

-modelNow.setParams(modelExpectedArch.params());
+modelNow.setParams(modelExpectedArch.getModelParams());
 //fit should give the same results
 modelExpectedArch.fit(randomData);
 modelNow.fit(randomData);
-assertEquals(modelExpectedArch.score(), modelNow.score(), 0.000001);
-assertEquals(modelExpectedArch.params(), modelNow.params());
+assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 0.000001);
+assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
 }


@@ -227,20 +227,20 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 modelExpectedArch.init();

 //modelNow should have the same architecture as modelExpectedArch
-assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
+assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());

-modelNow.setParams(modelExpectedArch.params());
+modelNow.setParams(modelExpectedArch.getModelParams());
 //fit should give the same results
 modelExpectedArch.fit(randomData);
 modelNow.fit(randomData);
-double scoreExpected = modelExpectedArch.score();
-double scoreActual = modelNow.score();
+double scoreExpected = modelExpectedArch.getScore();
+double scoreActual = modelNow.getScore();
 assertEquals(scoreExpected, scoreActual, 1e-4);
-assertEquals(modelExpectedArch.params(), modelNow.params());
+assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
 }

 @Test
@@ -370,14 +370,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(),
 modelNow.getNetConfiguration().getConf(5).toJson());

-assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
 //subsampling has no params
 //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(4).params().shape(), modelNow.getLayer(4).params().shape());
-assertArrayEquals(modelExpectedArch.getLayer(5).params().shape(), modelNow.getLayer(5).params().shape());
+assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(4).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+assertArrayEquals(modelExpectedArch.getLayer(5).getParams().shape(), modelNow.getLayer(5).getParams().shape());

 }

@@ -449,23 +449,23 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 .inputType(InputType.convolutionalFlat(12, 12, 20)).build());
 notFrozen.init();

-assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
 //subsampling has no params
 //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape());
-modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params());
+assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
 //subsampling has no params
 //assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape());
-assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape());
-modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params());
-assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape());
-modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params());
-assertArrayEquals(notFrozen.getLayer(4).params().shape(), modelNow.getLayer(6).params().shape());
-modelNow.getLayer(6).setParams(notFrozen.getLayer(4).params());
-assertArrayEquals(notFrozen.getLayer(5).params().shape(), modelNow.getLayer(7).params().shape());
-modelNow.getLayer(7).setParams(notFrozen.getLayer(5).params());
-assertArrayEquals(notFrozen.getLayer(6).params().shape(), modelNow.getLayer(8).params().shape());
-modelNow.getLayer(8).setParams(notFrozen.getLayer(6).params());
+assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
+assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
+modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());
+assertArrayEquals(notFrozen.getLayer(4).getParams().shape(), modelNow.getLayer(6).getParams().shape());
+modelNow.getLayer(6).setParams(notFrozen.getLayer(4).getParams());
+assertArrayEquals(notFrozen.getLayer(5).getParams().shape(), modelNow.getLayer(7).getParams().shape());
+modelNow.getLayer(7).setParams(notFrozen.getLayer(5).getParams());
+assertArrayEquals(notFrozen.getLayer(6).getParams().shape(), modelNow.getLayer(8).getParams().shape());
+modelNow.getLayer(8).setParams(notFrozen.getLayer(6).getParams());

 int i = 0;
 while (i < 3) {
@@ -474,8 +474,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 i++;
 }

-INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params());
-assertEquals(expectedParams, modelNow.params());
+INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
+assertEquals(expectedParams, modelNow.getModelParams());
 }


@@ -503,13 +503,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {


 //Check original net isn't modified:
-BaseLayer l0 = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
 assertEquals(new Adam(1e-4), l0.getIUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
 assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);

-BaseLayer l1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
 assertEquals(new Adam(1e-4), l1.getIUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
 assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@@ -518,13 +518,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 assertEquals(BackpropType.Standard, conf.getBackpropType());

 //Check new net has only the appropriate things modified (i.e., LR)
-l0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration();
+l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
 assertEquals(new Adam(2e-2), l0.getIUpdater());
 assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
 assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
 assertEquals(0.1, TestUtils.getL1(l0), 1e-6);

-l1 = (BaseLayer) net2.getLayer(1).getLayerConfiguration();
+l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
 assertEquals(new Adam(2e-2), l1.getIUpdater());
 assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
 assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@@ -586,17 +586,17 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 .build());
 notFrozen.init();

-assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
 //subsampling has no params
 //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape());
-modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params());
-assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape());
-modelNow.getLayer(3).setParams(notFrozen.getLayer(1).params());
-assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape());
-modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params());
-assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape());
-modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params());
+assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
+assertArrayEquals(notFrozen.getLayer(1).getParams().shape(), modelNow.getLayer(3).getParams().shape());
+modelNow.getLayer(3).setParams(notFrozen.getLayer(1).getParams());
+assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
+assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
+modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());

 int i = 0;
 while (i < 3) {
@@ -605,8 +605,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 i++;
 }

-INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params());
-assertEquals(expectedParams, modelNow.params());
+INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
+assertEquals(expectedParams, modelNow.getModelParams());
 }

 @Test
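TransferLearningMLNTest repeats the same renames for MultiLayerNetwork; note that the setter setParams(...) keeps its old name, only the getters changed. A sketch of the keep-in-sync pattern used throughout the file (variable names follow the test; the tolerance value is illustrative):

    modelNow.setParams(modelExpectedArch.getModelParams()); // setter name unchanged
    modelExpectedArch.fit(randomData);
    modelNow.fit(randomData);
    assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
    assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());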
@@ -99,7 +99,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -144,7 +144,7 @@ public class TestUpdaters extends BaseDL4JTest {
 msdx.put(key, msdxTmp);
 count++;
 }
-assertEquals(rho, ((AdaDelta)layer.layerConf().getIUpdater()).getRho(), 1e-4);
+assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4);
 }

 assertEquals(4, count);
@@ -165,7 +165,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -185,7 +185,7 @@ public class TestUpdaters extends BaseDL4JTest {
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 count++;
 }
-assertEquals(lr, ((AdaGrad)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
 assertEquals(2, count);
 }

@@ -209,7 +209,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -245,8 +245,8 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }

-assertEquals(beta1, ((Adam)layer.layerConf().getIUpdater()).getBeta1(), 1e-4);
-assertEquals(beta2, ((Adam)layer.layerConf().getIUpdater()).getBeta2(), 1e-4);
+assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
+assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
 assertEquals(2, count);
 }

@@ -273,7 +273,7 @@ public class TestUpdaters extends BaseDL4JTest {
 layer.setBackpropGradientsViewArray(gradients);

 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -362,7 +362,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -398,8 +398,8 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }

-assertEquals(beta1, ((AdaMax)layer.layerConf().getIUpdater()).getBeta1(), 1e-4);
-assertEquals(beta2, ((AdaMax)layer.layerConf().getIUpdater()).getBeta2(), 1e-4);
+assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
+assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
 assertEquals(2, count);
 }

@@ -418,7 +418,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -443,7 +443,7 @@ public class TestUpdaters extends BaseDL4JTest {
 count++;
 }

-assertEquals(mu, ((Nesterovs)layer.layerConf().getIUpdater()).getMomentum(), 1e-4);
+assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4);
 assertEquals(2, count);
 }

@@ -465,7 +465,7 @@ public class TestUpdaters extends BaseDL4JTest {
 BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 Updater updater = UpdaterCreator.getUpdater(layer);
-int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
 INDArray updaterState = Nd4j.create(1, updaterStateSize);
 updater.setStateViewArray(layer, updaterState, true);

@@ -495,7 +495,7 @@ public class TestUpdaters extends BaseDL4JTest {
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 lastG.put(key, lastGTmp);
 }
-assertEquals(rmsDecay, ((RmsProp)layer.layerConf().getIUpdater()).getRmsDecay(), 1e-4);
+assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4);
 }

 @Test
@@ -527,7 +527,7 @@ public class TestUpdaters extends BaseDL4JTest {
 gradExpected = val.mul(lr);
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 }
-assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
 }


@@ -769,7 +769,7 @@ public class TestUpdaters extends BaseDL4JTest {
 gradExpected = val.mul(lr);
 assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
 }
-assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);


 //Test with pretrain == false
@@ -797,7 +797,7 @@ public class TestUpdaters extends BaseDL4JTest {
 layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
 layer.setBackpropGradientsViewArray(gradients);
 updater = UpdaterCreator.getUpdater(layer);
-assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
 }

 @Test
@@ -858,11 +858,11 @@ public class TestUpdaters extends BaseDL4JTest {
 //Check first updater block:
 UpdaterBlock ub0 = blocks.get(0);
 assertEquals(3, ub0.getLayersAndVariablesInBlock().size());
-assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(0).getParamName());
-assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.BIAS_KEY, ub0.getLayersAndVariablesInBlock().get(1).getParamName());
-assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getConfig().getLayerName());
+assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(2).getParamName());

 int nParams0 = 10 * 10 + 10 + 10 * 10;
@@ -875,7 +875,7 @@ public class TestUpdaters extends BaseDL4JTest {
 //Check second updater block:
 UpdaterBlock ub1 = blocks.get(1);
 assertEquals(1, ub1.getLayersAndVariablesInBlock().size());
-assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.BIAS_KEY, ub1.getLayersAndVariablesInBlock().get(0).getParamName());

 int nParams1 = 10;
@@ -888,9 +888,9 @@ public class TestUpdaters extends BaseDL4JTest {
 //Check third updater block:
 UpdaterBlock ub2 = blocks.get(2);
 assertEquals(2, ub2.getLayersAndVariablesInBlock().size());
-assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub2.getLayersAndVariablesInBlock().get(0).getParamName());
-assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
 assertEquals(DefaultParamInitializer.BIAS_KEY, ub2.getLayersAndVariablesInBlock().get(1).getParamName());

 int nParams2 = 10 * 10 + 10;
@@ -903,9 +903,9 @@ public class TestUpdaters extends BaseDL4JTest {
 //Check fourth updater block:
 UpdaterBlock ub3 = blocks.get(3);
 assertEquals(2, ub3.getLayersAndVariablesInBlock().size());
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
|
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
|
||||||
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName());
|
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName());
|
||||||
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
|
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
|
||||||
assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName());
|
assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName());
|
||||||
|
|
||||||
int nParams3 = 10 * 10 + 10;
|
int nParams3 = 10 * 10 + 10;
|
||||||
|
@ -918,9 +918,9 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
//Check fifth updater black
|
//Check fifth updater black
|
||||||
UpdaterBlock ub4 = blocks.get(4);
|
UpdaterBlock ub4 = blocks.get(4);
|
||||||
assertEquals(2, ub4.getLayersAndVariablesInBlock().size());
|
assertEquals(2, ub4.getLayersAndVariablesInBlock().size());
|
||||||
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
|
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
|
||||||
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName());
|
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName());
|
||||||
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
|
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
|
||||||
assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName());
|
assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName());
|
||||||
|
|
||||||
int nParams4 = 10 * 10 + 10;
|
int nParams4 = 10 * 10 + 10;
|
||||||
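
Note: the TestUpdaters hunks above are one mechanical migration: a layer's typed config is now reached via getTypedLayerConfiguration() instead of layerConf(), and its training-time view via getTrainingConfig() instead of getConfig(). A minimal sketch of the new call shape, using stand-in interfaces rather than the real DL4J types (only the two renamed getter names come from the diff; everything else here is invented for illustration):

    import org.junit.jupiter.api.Assertions;

    // Stand-in sketch of the renamed accessors; IUpdater/Layer etc. are
    // mimicked here and are not the real DL4J API surface.
    class AccessorRenameSketch {
        interface IUpdater { double getLearningRate(); }
        interface TypedLayerConfiguration { IUpdater getIUpdater(); }
        interface TrainingConfig { String getLayerName(); }

        interface Layer {
            TypedLayerConfiguration getTypedLayerConfiguration(); // was: layerConf()
            TrainingConfig getTrainingConfig();                   // was: getConfig()
        }

        static void check(Layer layer, double expectedLr, String expectedName) {
            // Same assertions the hunks rewrite, expressed against the new names
            Assertions.assertEquals(expectedLr,
                    layer.getTypedLayerConfiguration().getIUpdater().getLearningRate(), 1e-4);
            Assertions.assertEquals(expectedName, layer.getTrainingConfig().getLayerName());
        }
    }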

@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater.custom;

 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest {
                 .build();

         //First: Check updater config
-        assertTrue(((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
-        assertTrue(((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
-        assertTrue(((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
-        assertTrue(((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);
+        assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
+        assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
+        assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
+        assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);

-        CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater();
-        CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater();
+        CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater();
+        CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater();
         assertEquals(lr, u0_0.getLearningRate(), 1e-6);
         assertEquals(lr, u0_1.getLearningRate(), 1e-6);

-        Sgd u1_0 = (Sgd) ((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater();
-        Sgd u1_1 = (Sgd) ((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater();
+        Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater();
+        Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater();
         assertEquals(lr, u1_0.getLearningRate(), 1e-6);
         assertEquals(lr, u1_1.getLearningRate(), 1e-6);


@@ -81,7 +81,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());

         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());

         assertEquals(1.0, step, 1e-3);
     }
@@ -97,11 +97,11 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();

         BackTrackLineSearch lineSearch =
                         new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());

         assertEquals(1.0, step, 1e-3);
     }
@@ -118,18 +118,18 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();
         INDArray origGradient = layer.gradient().gradient().dup();

         NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction();
         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
-        INDArray currParams = layer.params();
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        INDArray currParams = layer.getModelParams();
         sf.step(currParams, origGradient, step);
         layer.setParamsTable(currParams);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());

-        score2 = layer.score();
+        score2 = layer.getScore();

         assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2);

@@ -146,19 +146,19 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();
         INDArray origGradient = layer.gradient().gradient().dup();

         DefaultStepFunction sf = new DefaultStepFunction();
         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(),
+        double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(),
                         layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable());

-        INDArray currParams = layer.params();
+        INDArray currParams = layer.getModelParams();
         sf.step(currParams, origGradient, step);
         layer.setParamsTable(currParams);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score2 = layer.score();
+        score2 = layer.getScore();

         assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2);
     }
@@ -190,12 +190,12 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);
         for( int i=0; i<100; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < oldScore);
     }

@@ -208,13 +208,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double firstScore = network.score(data);

         for( int i=0; i<5; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < firstScore);

     }
@@ -227,13 +227,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);

         for( int i=0; i<5; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < oldScore);

     }
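
Note: BackTrackLineSearchTest shows the companion rename on the model side: params() becomes getModelParams() and score() becomes getScore(). A toy holder implementing just the two renamed state accessors (only those two method names are taken from the diff; the class and fields are a sketch, not the DL4J IModel contract):

    import org.nd4j.linalg.api.ndarray.INDArray;

    // Toy model-state holder illustrating the renamed accessors.
    class ModelStateSketch {
        private final INDArray parameters;
        private double score;

        ModelStateSketch(INDArray parameters) {
            this.parameters = parameters;
        }

        INDArray getModelParams() {   // was: params()
            return parameters;
        }

        double getScore() {           // was: score()
            return score;
        }

        void setScore(double score) { // hypothetical setter for the sketch
            this.score = score;
        }
    }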

@@ -28,6 +28,7 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.*;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -211,38 +212,38 @@ public class TestOptimizers extends BaseDL4JTest {
         System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= "
                         + nDimensions);

-        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter)
+        LayerConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter)
                         .updater(new Sgd(1e-2))
-                        .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
-        conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
+                        .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build().getFlattenedLayerConfigurations().get(0);
+        conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

         Random rng = new DefaultRandom(12345L);
         org.nd4j.linalg.api.rng.distribution.Distribution dist =
                         new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10);
         IModel m = new SphereFunctionModel(nDimensions, dist, conf);
         m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        double scoreBefore = m.score();
+        double scoreBefore = m.getScore();
         assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore));
         if (PRINT_OPT_RESULTS) {
             System.out.println("Before:");
             System.out.println(scoreBefore);
-            System.out.println(m.params());
+            System.out.println(m.getModelParams());
         }

-        ConvexOptimizer opt = getOptimizer(oa, conf, m);
+        ConvexOptimizer opt = getOptimizer(oa, conf.getNetConfiguration(), m);

         opt.setupSearchState(m.gradientAndScore());
         for( int i=0; i<100; i++ ) {
             opt.optimize(LayerWorkspaceMgr.noWorkspaces());
         }
         m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        double scoreAfter = m.score();
+        double scoreAfter = m.getScore();

         assertTrue(!Double.isNaN(scoreAfter) && !Double.isInfinite(scoreAfter));
         if (PRINT_OPT_RESULTS) {
             System.out.println("After:");
             System.out.println(scoreAfter);
-            System.out.println(m.params());
+            System.out.println(m.getModelParams());
         }

         //Expected behaviour after optimization:
@@ -279,17 +280,17 @@ public class TestOptimizers extends BaseDL4JTest {
                         .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
         conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-        IModel m = new SphereFunctionModel(100, dist, conf);
+        IModel m = new SphereFunctionModel(100, dist, conf.getFlattenedLayerConfigurations().get(0));
         if (i == 0) {
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[0] = m.score(); //Before optimization
+            scores[0] = m.getScore(); //Before optimization
         } else {
             ConvexOptimizer opt = getOptimizer(oa, conf, m);
             for( int j=0; j<100; j++ ) {
                 opt.optimize(LayerWorkspaceMgr.noWorkspaces());
             }
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[i] = m.score();
+            scores[i] = m.getScore();
             assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
         }
     }
@@ -316,7 +317,7 @@ public class TestOptimizers extends BaseDL4JTest {
         private static final long serialVersionUID = -6963606137417355405L;

         private SphereFunctionModel(int nParams, org.nd4j.linalg.api.rng.distribution.Distribution distribution,
-                        NeuralNetConfiguration conf) {
+                        LayerConfiguration conf) {
             super(distribution.sample(new int[] {1, nParams}), conf);
         }

@@ -437,7 +438,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -499,17 +500,17 @@ public class TestOptimizers extends BaseDL4JTest {
                         .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
         conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-        IModel m = new RastriginFunctionModel(10, conf);
+        IModel m = new RastriginFunctionModel(10, conf.getFlattenedLayerConfigurations().get(0));
         int nParams = (int)m.numParams();
         if (i == 0) {
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[0] = m.score(); //Before optimization
+            scores[0] = m.getScore(); //Before optimization
         } else {
             ConvexOptimizer opt = getOptimizer(oa, conf, m);
             opt.getUpdater().setStateViewArray((Layer) m, Nd4j.create(new int[] {1, nParams}, 'c'), true);
             opt.optimize(LayerWorkspaceMgr.noWorkspaces());
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[i] = m.score();
+            scores[i] = m.getScore();
             assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
         }
     }
@@ -540,7 +541,7 @@ public class TestOptimizers extends BaseDL4JTest {
     private static class RastriginFunctionModel extends SimpleOptimizableModel {
         private static final long serialVersionUID = -1772954508787487941L;

-        private RastriginFunctionModel(int nDimensions, NeuralNetConfiguration conf) {
+        private RastriginFunctionModel(int nDimensions, LayerConfiguration conf) {
             super(initParams(nDimensions), conf);
         }

@@ -710,7 +711,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -768,15 +769,15 @@ public class TestOptimizers extends BaseDL4JTest {
                         .build();
         conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-        IModel m = new RosenbrockFunctionModel(100, conf);
+        IModel m = new RosenbrockFunctionModel(100, conf.getFlattenedLayerConfigurations().get(0));
         if (i == 0) {
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[0] = m.score(); //Before optimization
+            scores[0] = m.getScore(); //Before optimization
         } else {
             ConvexOptimizer opt = getOptimizer(oa, conf, m);
             opt.optimize(LayerWorkspaceMgr.noWorkspaces());
             m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            scores[i] = m.score();
+            scores[i] = m.getScore();
             assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]), "NaN or infinite score: " + scores[i]);
         }
     }
@@ -810,7 +811,7 @@ public class TestOptimizers extends BaseDL4JTest {
     private static class RosenbrockFunctionModel extends SimpleOptimizableModel {
         private static final long serialVersionUID = -5129494342531033706L;

-        private RosenbrockFunctionModel(int nDimensions, NeuralNetConfiguration conf) {
+        private RosenbrockFunctionModel(int nDimensions, LayerConfiguration conf) {
             super(initParams(nDimensions), conf);
         }

@@ -995,7 +996,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -1029,13 +1030,31 @@ public class TestOptimizers extends BaseDL4JTest {
         private static final long serialVersionUID = 4409380971404019303L;
         protected INDArray parameters;
         protected INDArray gradientView;
-        protected final NeuralNetConfiguration conf;
+        protected final LayerConfiguration conf;
         protected Gradient gradient;
         protected double score;

+        /**
+         * @return 1d parameter vector
+         */
+        @Override
+        public INDArray getParams() {
+            throw new RuntimeException("Not implemented");
+        }
+
+        /**
+         * Get a reference to the network this layer is part of.
+         *
+         * @return
+         */
+        @Override
+        public IModel getNet() {
+            throw new RuntimeException("Not implemented");
+        }
+
         /**@param parameterInit Initial parameters. Also determines dimensionality of problem. Should be row vector.
          */
-        private SimpleOptimizableModel(INDArray parameterInit, NeuralNetConfiguration conf) {
+        private SimpleOptimizableModel(INDArray parameterInit, LayerConfiguration conf) {
             this.parameters = parameterInit.dup();
             this.gradientView = Nd4j.create(parameterInit.shape());
             this.conf = conf;
@@ -1048,17 +1067,12 @@ public class TestOptimizers extends BaseDL4JTest {
          */
         @Override
         public LayerConfiguration getLayerConfiguration() {
-            return this.conf.getFirstLayer();
+            return this.conf;
         }

         @Override
-        public void addListeners(TrainingListener... listener) {
-            // no-op
-        }
-
-        @Override
-        public TrainingConfig getConfig() {
-            return conf.getFirstLayer();
+        public ITraininableLayerConfiguration getTrainingConfig() {
+            return (BaseLayerConfiguration) conf;
         }

         /**
@@ -1092,7 +1106,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -1112,7 +1126,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public double score() {
+        public double getScore() {
             return score;
         }

@@ -1132,7 +1146,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public INDArray params() {
+        public INDArray getModelParams() {
             return parameters;
         }

@@ -1154,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest {
         @Override
         public Pair<Gradient, Double> gradientAndScore() {
             computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            return new Pair<>(gradient(), score());
+            return new Pair<>(gradient(), getScore());
         }

         @Override
@@ -1164,7 +1178,7 @@ public class TestOptimizers extends BaseDL4JTest {

         @Override
         public NeuralNetConfiguration getNetConfiguration() {
-            return conf;
+            return conf.getNetConfiguration();
         }

         @Override
@@ -1225,12 +1239,12 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public Collection<TrainingListener> getListeners() {
+        public Collection<TrainingListener> getTrainingListeners() {
             return null;
         }

         @Override
-        public void setListeners(Collection<TrainingListener> listeners) {
+        public void addTrainingListeners(Collection<TrainingListener> listeners) {
             throw new UnsupportedOperationException();
         }

@@ -1310,4 +1324,6 @@ public class TestOptimizers extends BaseDL4JTest {
         public void close(){
         }
     }
+
+
 }
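
Note: in TestOptimizers the synthetic models now hold a single LayerConfiguration instead of the whole NeuralNetConfiguration, pulled out with getFlattenedLayerConfigurations().get(0), and the enclosing net config is recovered from the layer via getNetConfiguration(). A stand-in sketch of that ownership flip (the two method names and the field-type change come from the hunks; the interfaces below are mimicked, not the real DL4J types):

    import java.util.List;

    // Stand-in sketch of the config ownership flip in SimpleOptimizableModel.
    class ConfigOwnershipSketch {
        interface NetConfiguration {
            List<LayerConfiguration> getFlattenedLayerConfigurations();
        }

        interface LayerConfiguration {
            NetConfiguration getNetConfiguration();
        }

        static class SimpleModel {
            private final LayerConfiguration conf;  // was: NeuralNetConfiguration

            SimpleModel(NetConfiguration netConf) {
                // New pattern from the diff: keep only the first flattened layer config
                this.conf = netConf.getFlattenedLayerConfigurations().get(0);
            }

            NetConfiguration getNetConfiguration() {
                return conf.getNetConfiguration(); // recover the enclosing config
            }
        }
    }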

@@ -76,7 +76,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .keepAll()
                         .saveEveryNEpochs(2)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<10; i++ ){
             net.fit(iter);
@@ -125,7 +125,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .keepLast(3)
                         .saveEveryNIterations(5)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<20; i++ ){ //40 iterations total
             net.fit(iter);
@@ -167,7 +167,7 @@ public class TestCheckpointListener extends BaseDL4JTest {

         MultiLayerNetwork netStatic2 = CheckpointListener.loadLastCheckpointMLN(f);
         assertEquals(35, netStatic2.getIterationCount());
-        assertEquals(netStatic.params(), netStatic2.params());
+        assertEquals(netStatic.getModelParams(), netStatic2.getModelParams());
     }

     @Test
@@ -182,7 +182,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .keepLast(3)
                         .saveEvery(4900, TimeUnit.MILLISECONDS)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<3; i++ ){ //10 iterations total
             net.fit(iter);
@@ -226,7 +226,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .keepLastAndEvery(3, 3)
                         .saveEveryNEpochs(2)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<20; i++ ){ //40 iterations total
             net.fit(iter);
@@ -272,7 +272,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .keepAll()
                         .saveEveryNEpochs(1)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<3; i++ ){
             net.fit(iter);
@@ -294,7 +294,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                         .saveEveryNEpochs(1)
                         .deleteExisting(true)
                         .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         net.fit(iter);


@@ -58,7 +58,7 @@ public class TestFailureListener extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
 //                FailureTestingListener.FailureMode.OOM,
                 FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
                 new FailureTestingListener.IterationEpochTrigger(false, 10)));
@@ -84,7 +84,7 @@ public class TestFailureListener extends BaseDL4JTest {
         assertNotNull(username);
         assertFalse(username.isEmpty());

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
                 FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
                 new FailureTestingListener.Or(
                         new FailureTestingListener.IterationEpochTrigger(false, 10000),
@@ -112,7 +112,7 @@ public class TestFailureListener extends BaseDL4JTest {
         assertNotNull(hostname);
         assertFalse(hostname.isEmpty());

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
                 FailureTestingListener.FailureMode.ILLEGAL_STATE,
                 new FailureTestingListener.And(
                         new FailureTestingListener.HostNameTrigger(hostname),

@@ -77,17 +77,17 @@ public class TestListeners extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        net.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+        net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());

         for (Layer l : net.getLayers()) {
-            Collection<TrainingListener> layerListeners = l.getListeners();
+            Collection<TrainingListener> layerListeners = l.getTrainingListeners();
             assertEquals(2, layerListeners.size(), l.getClass().toString());
             TrainingListener[] lArr = layerListeners.toArray(new TrainingListener[2]);
             assertTrue(lArr[0] instanceof ScoreIterationListener);
             assertTrue(lArr[1] instanceof TestRoutingListener);
         }

-        Collection<TrainingListener> netListeners = net.getListeners();
+        Collection<TrainingListener> netListeners = net.getTrainingListeners();
         assertEquals(2, netListeners.size());
         TrainingListener[] lArr = netListeners.toArray(new TrainingListener[2]);
         assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -101,17 +101,17 @@ public class TestListeners extends BaseDL4JTest {
         ComputationGraph cg = new ComputationGraph(gConf);
         cg.init();

-        cg.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+        cg.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());

         for (Layer l : cg.getLayers()) {
-            Collection<TrainingListener> layerListeners = l.getListeners();
+            Collection<TrainingListener> layerListeners = l.getTrainingListeners();
             assertEquals(2, layerListeners.size());
             lArr = layerListeners.toArray(new TrainingListener[2]);
             assertTrue(lArr[0] instanceof ScoreIterationListener);
             assertTrue(lArr[1] instanceof TestRoutingListener);
         }

-        netListeners = cg.getListeners();
+        netListeners = cg.getTrainingListeners();
         assertEquals(2, netListeners.size());
         lArr = netListeners.toArray(new TrainingListener[2]);
         assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -180,7 +180,7 @@ public class TestListeners extends BaseDL4JTest {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(listeners);
+        net.addTrainingListeners(listeners);

         net.fit(iter);

@@ -199,7 +199,7 @@ public class TestListeners extends BaseDL4JTest {
             listeners2.add(il2);
         }

-        net.setListeners(listeners2);
+        net.addTrainingListeners(listeners2);
         net.fit(iter);
     }

@@ -216,7 +216,7 @@ public class TestListeners extends BaseDL4JTest {
         net.init();

         TestListener tl = new TestListener();
-        net.setListeners(tl);
+        net.addTrainingListeners(tl);

         DataSetIterator irisIter = new IrisDataSetIterator(50, 150);

@@ -260,7 +260,7 @@ public class TestListeners extends BaseDL4JTest {
         tl = new TestListener();

         ComputationGraph cg = net.toComputationGraph();
-        cg.setListeners(tl);
+        cg.addTrainingListeners(tl);

         cg.fit(irisIter, 2);

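
Note: every listener call site is moved from setListeners(...) to addTrainingListeners(...), and retrieval moves from getListeners() to getTrainingListeners(); TestListeners still expects exactly two listeners after a single registration call. A toy registry with the renamed pair (only the two names come from the diff; the append-rather-than-replace semantics sketched here are an assumption suggested by the rename itself):

    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.Collections;
    import java.util.List;

    // Toy listener registry illustrating the renamed registration API.
    class ListenerRegistrySketch<L> {
        private final List<L> listeners = new ArrayList<>();

        @SafeVarargs
        final void addTrainingListeners(L... toAdd) {    // was: setListeners(...)
            Collections.addAll(listeners, toAdd);        // append, do not replace
        }

        void addTrainingListeners(Collection<? extends L> toAdd) {
            listeners.addAll(toAdd);
        }

        Collection<L> getTrainingListeners() {           // was: getListeners()
            return Collections.unmodifiableList(listeners);
        }
    }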

@@ -94,7 +94,7 @@ public class RandomTests extends BaseDL4JTest {

         // at the end of day, model params has to
         for (int i = 0; i < models.size(); i++) {
-            assertEquals(models.get(0).params(), models.get(i).params());
+            assertEquals(models.get(0).getModelParams(), models.get(i).getModelParams());
         }
     }

@@ -119,7 +119,7 @@ public class RandomTests extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
         net2.init();

-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());

         NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json);

@@ -127,6 +127,6 @@ public class RandomTests extends BaseDL4JTest {
         MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
         net3.init();

-        assertEquals(net1.params(), net3.params());
+        assertEquals(net1.getModelParams(), net3.getModelParams());
     }
 }

@@ -63,7 +63,7 @@ public class TestSystemInfoPrintListener extends BaseDL4JTest {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(systemInfoFilePrintListener);
+        net.addTrainingListeners(systemInfoFilePrintListener);

         DataSetIterator iter = new IrisDataSetIterator(10, 150);


@@ -87,7 +87,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(net.numParams());
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -126,7 +126,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -170,7 +170,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }

@@ -89,7 +89,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -132,7 +132,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -178,7 +178,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }

@@ -90,7 +90,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);

         long numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -133,7 +133,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);

         long numParams = net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -179,7 +179,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);

         long numParams = net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }

@@ -94,7 +94,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertEquals(0.15, n.getLearningRate(), 1e-6);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -143,7 +143,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -194,7 +194,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);

         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -97,7 +97,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {

         assertEquals(dt, in.dataType());
         assertEquals(dt, outExp.dataType());
-        assertEquals(dt, net.params().dataType());
+        assertEquals(dt, net.getModelParams().dataType());
         assertEquals(dt, net.getFlattenedGradients().dataType());
         assertEquals(dt, net.getUpdater().getStateViewArray().dataType());

@@ -109,7 +109,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
         List<INDArray> activations = net.feedForward(in);

         assertEquals(dt, net.getNetConfiguration().getDataType());
-        assertEquals(dt, net.params().dataType());
+        assertEquals(dt, net.getModelParams().dataType());
         assertEquals( outExp, outAct, dtype);
     }
 }
@@ -116,7 +116,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {

         assertEquals(dtype, in.dataType());
         assertEquals(dtype, outExp.dataType());
-        assertEquals(dtype, net.params().dataType());
+        assertEquals(dtype, net.getModelParams().dataType());
         assertEquals(dtype, net.getFlattenedGradients().dataType());
         assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());

@@ -126,7 +126,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
         assertEquals(dtype, outAct.dataType());

         assertEquals(dtype, net.getNetConfiguration().getDataType());
-        assertEquals(dtype, net.params().dataType());
+        assertEquals(dtype, net.getModelParams().dataType());
         boolean eq = outExp.equalsWithEps(outAct, 0.01);
         assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct);
     }
@@ -98,7 +98,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {

         assertEquals(dtype, in.dataType());
         assertEquals(dtype, outExp.dataType());
-        assertEquals(dtype, net.params().dataType());
+        assertEquals(dtype, net.getModelParams().dataType());
         assertEquals(dtype, net.getFlattenedGradients().dataType());
         assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());

@@ -108,7 +108,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
         assertEquals(dtype, outAct.dataType());

         assertEquals(dtype, net.getNetConfiguration().getDataType());
-        assertEquals(dtype, net.params().dataType());
+        assertEquals(dtype, net.getModelParams().dataType());
         boolean eq = outExp.equalsWithEps(outAct, 0.01);
         assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct);
     }
@@ -76,7 +76,7 @@ public class CustomLayer extends FeedForwardLayer {
         //For the most part, it's the same for each type of layer

         CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType);
-        myCustomLayer.setListeners(iterationListeners);             //Set the iteration listeners, if any
+        myCustomLayer.addTrainingListeners(iterationListeners);     //Set the iteration listeners, if any
         myCustomLayer.setIndex(layerIndex);                         //Integer index of the layer

         //Parameter view array: In Deeplearning4j, the network parameters for the entire network (all layers) are
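
Note: listener wiring uses the renamed hook throughout this commit — setListeners(...) becomes addTrainingListeners(...) on layers and networks alike. A fragment of the instantiate body above, kept as a sketch (lconf, networkDataType, iterationListeners and layerIndex are the enclosing method's parameters):

    CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType);
    myCustomLayer.addTrainingListeners(iterationListeners); // was setListeners(...): set the iteration listeners, if any
    myCustomLayer.setIndex(layerIndex);                     // integer index of the layer within the network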
@@ -20,7 +20,6 @@

 package org.deeplearning4j.regressiontest.customlayer100a;

-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;

@@ -56,7 +55,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type
         INDArray firstHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
         INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));

-        IActivation activation1 = layerConf().getActivationFn();
+        IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
         IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();

         //IActivation function instances modify the activation functions in-place
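
Note: the typed accessor replaces layerConf(). getTypedLayerConfiguration() already returns the configuration narrowed to the layer's generic parameter, while getLayerConfiguration() returns the untyped base and still needs a cast for fields specific to the custom configuration. A sketch of both, as used in the hunk above (inside CustomLayerImpl extends BaseLayer<CustomLayer>):

    // typed: no cast needed for fields inherited from FeedForwardLayer
    IActivation activation1 = getTypedLayerConfiguration().getActivationFn();

    // untyped: cast required to reach the CustomLayer-specific second activation function
    IActivation activation2 =
            ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();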
@@ -75,7 +74,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type
     @Override
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         /*
-        The backprop gradient method here is very similar to the BaseLayer backprop gradient implementation
+        The backprop gradient method here is very similar to the BaseLayerConfiguration backprop gradient implementation
         The only major difference is the two activation functions we have added in this example.

         Note that epsilon is dL/da - i.e., the derivative of the loss function with respect to the activations.
@@ -105,14 +104,14 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type
         INDArray epsilonFirstHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
         INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));

-        IActivation activation1 = layerConf().getActivationFn();
+        IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
         IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();

         //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz
         activation1.backprop(firstHalf, epsilonFirstHalf);
         activation2.backprop(secondHalf, epsilonSecondHalf);

-        //The remaining code for this method: just copy & pasted from BaseLayer.backpropGradient
+        //The remaining code for this method: just copy & pasted from BaseLayerConfiguration.backpropGradient
         // INDArray delta = epsilon.muli(activationDerivative);
         if (maskArray != null) {
             activationDerivative.muliColumnVector(maskArray);
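
Note: the comments in this hunk explain the custom layer's trick — the output columns are split in half, each half gets its own activation function, and backprop applies each function's backprop(...) to its half of epsilon (dL/da) in place to obtain dL/dz. A compressed sketch of that split, assuming 'output' and 'epsilon' are 2-D with matching shapes and activation1/activation2 are resolved as above:

    // split pre-activations and incoming gradient into matching column halves
    long columns = output.columns();
    INDArray firstHalf  = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
    INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
    INDArray epsFirst   = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
    INDArray epsSecond  = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));

    // backprop(...) modifies the arrays in place: dL/da -> dL/dz, per half
    activation1.backprop(firstHalf, epsFirst);
    activation2.backprop(secondHalf, epsSecond);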
@@ -128,7 +127,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type
         ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
         ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad);

-        INDArray epsilonNext = paramsTable.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();
+        INDArray epsilonNext = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();

         return new Pair<>(ret, epsilonNext);
     }
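
Note: direct field access to the parameter table also moves behind a getter — paramsTable.get(...) becomes getParamTable().get(...). The epsilon-for-the-layer-below computation from the hunk, restated as a sketch ('activationDerivative' is assumed to already hold dL/dz):

    // W * (dL/dz)^T, transposed back to one row per example
    INDArray epsilonNext = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY)
            .mmul(activationDerivative.transpose())
            .transpose();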
@@ -190,7 +190,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {

             //Check score
-            double scoreDl4j = net.score();
+            double scoreDl4j = net.getScore();
             double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
             assertEquals(scoreDl4j, scoreSd, 1e-6, testName);
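
Note: score() is renamed getScore() under the same getter convention. The test rebuilds the DL4J score from SameDiff pieces — the MSE loss output plus the regularization penalty — and requires agreement to 1e-6; 'net', 'sd', 'map', 'lossMse' and 'testName' below are assumed from the surrounding test:

    double scoreDl4j = net.getScore();   // was net.score()
    // SameDiff side: loss value plus L1/L2-style regularization terms
    double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
    assertEquals(scoreDl4j, scoreSd, 1e-6, testName);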
@@ -104,7 +104,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.addListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

         //Test net that hasn't been trained yet
         Exception e = new Exception();

@@ -161,7 +161,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
         CrashReportingUtil.crashDumpOutputDirectory(dir);

         ComputationGraph cg = net.toComputationGraph();
-        cg.setListeners(new ScoreIterationListener(1));
+        cg.addTrainingListeners(new ScoreIterationListener(1));

         //Test net that hasn't been trained yet
         CrashReportingUtil.writeMemoryCrashDump(cg, e);
@@ -156,7 +156,7 @@ public class ModelGuesserTest extends BaseDL4JTest {

         MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());

     }

@@ -173,7 +173,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
         MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
         Assertions.assertNotNull(network);
         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
 }
@@ -81,7 +81,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);

         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }

@@ -125,7 +125,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis);

         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }

@@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);

         assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-        assertEquals(cg.params(), network.params());
+        assertEquals(cg.getModelParams(), network.getModelParams());
         assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }

@@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);

         assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-        assertEquals(cg.params(), network.params());
+        assertEquals(cg.getModelParams(), network.getModelParams());
         assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
@@ -346,7 +346,7 @@ public class ModelSerializerTest extends BaseDL4JTest {

         //Also test reading both model and normalizer from stream (correctly)
         Pair<MultiLayerNetwork,Normalizer> pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true);
-        assertEquals(net.params(), pair.getFirst().params());
+        assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
         assertNotNull(pair.getSecond());
     }

@@ -395,7 +395,7 @@ public class ModelSerializerTest extends BaseDL4JTest {

         //Also test reading both model and normalizer from stream (correctly)
         Pair<ComputationGraph,Normalizer> pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true);
-        assertEquals(net.params(), pair.getFirst().params());
+        assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
         assertNotNull(pair.getSecond());
     }

@@ -496,6 +496,6 @@ public class ModelSerializerTest extends BaseDL4JTest {
         assertTrue(entries.contains("otherData.bin"));

         ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile);
-        assertEquals(net.params(), restoredNet.params());
+        assertEquals(net.getModelParams(), restoredNet.getModelParams());
     }
 }
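
Note: all the serializer tests converge on the same equality check after a round trip. A condensed sketch of the model-plus-normalizer variant, assuming 'net' is the original network and 'tempFile' holds a model written together with its normalizer (import paths are assumptions based on current ND4J/DL4J packages):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
    import org.nd4j.common.primitives.Pair;
    import java.io.FileInputStream;
    import static org.junit.jupiter.api.Assertions.*;

    Pair<MultiLayerNetwork, Normalizer> pair =
            ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true);
    // was: net.params() vs pair.getFirst().params()
    assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
    assertNotNull(pair.getSecond()); // the normalizer survived the round trip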
@@ -21,7 +21,6 @@
 package org.deeplearning4j.nn.modelimport.keras.layers;

 import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;

@@ -80,10 +79,6 @@ public class TFOpLayer extends LayerConfiguration {
     public void setNIn(InputType inputType, boolean override){}


-    @Override
-    public GradientNormalization getGradientNormalization(){return null;}
-
-
     @Override
     public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
             Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,

@@ -91,14 +86,11 @@ public class TFOpLayer extends LayerConfiguration {

         LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
         TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, lconf, networkDataType);
-        tfOpLayerImpl.setListeners(trainingListeners);
+        tfOpLayerImpl.addTrainingListeners(trainingListeners);
         tfOpLayerImpl.setIndex(layerIndex);
         return tfOpLayerImpl;
     }

-    @Override
-    public double getGradientNormalizationThreshold(){return 0.;}
-
     @Override
     public List<Regularization> getRegularizationByParam(String paramName){return null;}

@@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
-import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -448,8 +448,8 @@ public class KerasLSTM extends KerasLayer {


         FeedForwardLayer ffl;
-        if(this.layer instanceof BaseWrapperLayer){
-            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+        if(this.layer instanceof BaseWrapperLayerConfiguration){
+            BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration)this.layer;
             ffl = (FeedForwardLayer)bwl.getUnderlying();
         } else {
             ffl = (FeedForwardLayer) this.layer;
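
Note: BaseWrapperLayer is renamed BaseWrapperLayerConfiguration on the configuration side; the Keras LSTM importer unwraps it to reach the underlying FeedForwardLayer. A sketch of that unwrap, as in the hunk above (the wrapper is e.g. a LastTimeStep or MaskZeroLayer from this file's imports):

    FeedForwardLayer ffl;
    if (this.layer instanceof BaseWrapperLayerConfiguration) {
        // wrapped layer, e.g. LastTimeStep/MaskZeroLayer around the actual LSTM
        BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration) this.layer;
        ffl = (FeedForwardLayer) bwl.getUnderlying();
    } else {
        ffl = (FeedForwardLayer) this.layer;
    }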