Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>

commit 0f21ed9ec5
parent a5dfdcb18f
@@ -47,6 +47,7 @@ import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -54,9 +55,11 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -181,6 +184,7 @@ public class App {
         .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
         .gradientNormalizationThreshold( 100 )
         //.weightInitFn( new WeightInitXavier() ) //this is internal
+            .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
         .weightInit( WeightInit.XAVIER)
         //.activationFn( new ActivationIdentity()) //this is internal
         .activation( Activation.IDENTITY )
@@ -232,10 +236,10 @@ public class App {
 
     copyParams(gen, dis, gan);
 
-    //gen.setListeners(new PerformanceListener(10, true));
-    //dis.setListeners(new PerformanceListener(10, true));
-    //gan.setListeners(new PerformanceListener(10, true));
-    gan.setListeners(new ScoreToChartListener("gan"));
+    gen.addTrainingListeners(new PerformanceListener(10, true));
+    dis.addTrainingListeners(new PerformanceListener(10, true));
+    gan.addTrainingListeners(new PerformanceListener(10, true));
+    gan.addTrainingListeners(new ScoreToChartListener("gan"));
     //dis.setListeners(new ScoreToChartListener("dis"));
 
     gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
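
Note: the replacement of setListeners(...) with addTrainingListeners(...) in the hunk above recurs throughout this commit. A minimal usage sketch, assuming the renamed method registers listeners additively (the name suggests so, but the diff only shows the rename); the listener calls are taken verbatim from the hunk:

    MultiLayerNetwork gan = new MultiLayerNetwork(conf);          // conf as built earlier in App
    gan.init();
    gan.addTrainingListeners(new PerformanceListener(10, true));  // log training performance every 10 iterations
    gan.addTrainingListeners(new ScoreToChartListener("gan"));    // push scores to the "gan" chart
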
@@ -322,23 +326,25 @@ public class App {
     int genLayerCount = gen.getLayers().length;
     for (int i = 0; i < gan.getLayers().length; i++) {
       if (i < genLayerCount) {
-        gen.getLayer(i).setParams(gan.getLayer(i).params());
+        if(gan.getLayer(i).getParams() != null)
+         gen.getLayer(i).setParams(gan.getLayer(i).getParams());
       } else {
-        dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+        if(gan.getLayer(i).getParams() != null)
+        dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
       }
     }
   }
 
   private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
     for (int i = 0; i < gen.getLayers().length; i++) {
-      gen.getLayer(i).setParams(gan.getLayer(i).params());
+      gen.getLayer(i).setParams(gan.getLayer(i).getParams());
     }
   }
 
   private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
     int genLayerCount = gen.getLayers().length;
     for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-      gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
+      gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
     }
   }
 
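
Note: the null checks added in copyParams above presumably guard against layers that carry no parameters (App imports parameterless layer types such as ActivationLayer and DropoutLayer), for which getParams() can return null. A braced sketch of the same copy logic, with the guard hoisted out of the dangling-if form used in the diff; assumptions as above:

    for (int i = 0; i < gan.getLayers().length; i++) {
      INDArray p = gan.getLayer(i).getParams();
      if (p == null) continue;                         // skip parameterless layers
      if (i < genLayerCount) {
        gen.getLayer(i).setParams(p);                  // generator layers come first in the stacked GAN
      } else {
        dis.getLayer(i - genLayerCount).setParams(p);  // the remainder belongs to the discriminator
      }
    }
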
@@ -115,15 +115,15 @@ public class GAN {
 
 
   public void setGeneratorListeners(BaseTrainingListener[] listeners) {
-    generator.setListeners(listeners);
+    generator.addTrainingListeners(listeners);
   }
 
   public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
-    discriminator.setListeners(listeners);
+    discriminator.addTrainingListeners(listeners);
   }
 
   public void setGanListeners(BaseTrainingListener[] listeners) {
-    gan.setListeners(listeners);
+    gan.addTrainingListeners(listeners);
   }
 
   public void fit(DataSetIterator realData, int numEpochs) {
@@ -239,9 +239,9 @@ public class GAN {
     int genLayerCount = generator.getLayers().length;
     for (int i = 0; i < gan.getLayers().length; i++) {
       if (i < genLayerCount) {
-        generator.getLayer(i).setParams(gan.getLayer(i).params());
+        generator.getLayer(i).setParams(gan.getLayer(i).getParams());
       } else {
-        discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+        discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
       }
     }
   }
@@ -252,7 +252,7 @@ public class GAN {
    */
   private void updateGeneratorFromGan() {
     for (int i = 0; i < generator.getLayers().length; i++) {
-      generator.getLayer(i).setParams(gan.getLayer(i).params());
+      generator.getLayer(i).setParams(gan.getLayer(i).getParams());
     }
   }
 
@@ -263,7 +263,7 @@ public class GAN {
   private void updateGanWithDiscriminator() {
     int genLayerCount = generator.getLayers().length;
     for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-      gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
+      gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).getParams());
     }
   }
 
@@ -155,8 +155,8 @@ public class MnistDCGANExample {
         .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
         .build();
 
-    gan.getGenerator().setListeners(new PerformanceListener(1, true));
-    gan.getDiscriminator().setListeners(new PerformanceListener(1, true));
+    gan.getGenerator().addTrainingListeners(new PerformanceListener(1, true));
+    gan.getDiscriminator().addTrainingListeners(new PerformanceListener(1, true));
 
     Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);
 
@@ -205,7 +205,7 @@ public class TestServer2 {
         //PostgresStatsStorage psqlStore = new PostgresStatsStorage();
         int listenerFrequency = 2;
         //net.setListeners(new StatsListener(psqlStore, listenerFrequency), new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
-        net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
+        net.addTrainingListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
 
 
         //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
@@ -290,7 +290,7 @@ public class IntegrationTestBaselineGenerator {
                     for (int i : layersToTrain) {
                         mln.pretrainLayer(i, dsi);
                     }
-                    paramsPostTraining = mln.params();
+                    paramsPostTraining = mln.getModelParams();
                 } else if (modelType == ModelType.CG) {
                     String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
                     Preconditions.checkState(layersToTrain != null, "ILayer names must not be null");
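
Note: two related accessor renames run through the rest of this commit. On individual layers, params() becomes getParams() (see the App and GAN hunks above); on whole models, params() becomes getModelParams() (as in the hunk above). A before/after sketch, assuming getModelParams() is a straight rename that still returns the flattened parameter vector, which the mechanical one-for-one substitution in this diff suggests:

    INDArray flatOld = mln.params();          // pre-commit accessor
    INDArray flatNew = mln.getModelParams();  // renamed accessor used from this commit on
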
@@ -298,7 +298,7 @@ public class IntegrationTestBaselineGenerator {
                     for (String i : layersToTrain) {
                         cg.pretrainLayer(i, iter);
                     }
-                    paramsPostTraining = cg.params();
+                    paramsPostTraining = cg.getModelParams();
                 } else {
                     throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
                 }
@@ -314,7 +314,7 @@ public class IntegrationTestBaselineGenerator {
 
                 CollectScoresListener l = new CollectScoresListener(1);
                 if (modelType != ModelType.SAMEDIFF)
-                    m.setListeners(l);
+                    m.addTrainingListeners(l);
 
                 History h = null;
                 if (modelType == ModelType.MLN) {
@@ -349,7 +349,7 @@ public class IntegrationTestBaselineGenerator {
                         }
                     } else {
                         File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
-                        IntegrationTestRunner.write(m.params(), p);
+                        IntegrationTestRunner.write(m.getModelParams(), p);
                     }
                 }
             }
@@ -191,7 +191,7 @@ public class IntegrationTestRunner {
 
                 MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
                 assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal");
-                assertEquals( loaded.params(), mln.params(), "Params not equal");
+                assertEquals( loaded.getModelParams(), mln.getModelParams(), "Params not equal");
                 assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal");
             } else if(config instanceof ComputationGraphConfiguration ){
                 ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
@@ -201,7 +201,7 @@ public class IntegrationTestRunner {
 
                 ComputationGraph loaded = ComputationGraph.load(savedModel, true);
                 assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" );
-                assertEquals( loaded.params(), cg.params(), "Params not equal");
+                assertEquals( loaded.getModelParams(), cg.getModelParams(), "Params not equal");
                 assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal");
             } else if(config instanceof SameDiff){
                 sd = (SameDiff)config;
@@ -389,7 +389,7 @@ public class IntegrationTestRunner {
                 for( int i : layersToTrain){
                     mln.pretrainLayer(i, dsi);
                 }
-                paramsPostTraining = mln.params();
+                paramsPostTraining = mln.getModelParams();
                 layers = mln.getLayers();
             } else if(modelType == ModelType.CG) {
                 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
@@ -398,7 +398,7 @@ public class IntegrationTestRunner {
                 for( String i : layersToTrain){
                     cg.pretrainLayer(i, iter);
                 }
-                paramsPostTraining = cg.params();
+                paramsPostTraining = cg.getModelParams();
                 layers = cg.getLayers();
             } else {
                 throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
@@ -439,7 +439,7 @@ public class IntegrationTestRunner {
             CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
             CollectScoresListener l = new CollectScoresListener(1);
             if(modelType != ModelType.SAMEDIFF) {
-                m.setListeners(l);
+                m.addTrainingListeners(l);
             }
 
             int iterBefore;
@@ -519,10 +519,10 @@ public class IntegrationTestRunner {
                 if(modelType != ModelType.SAMEDIFF) {
                     File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
                     INDArray paramsExp = read(p);
-                    INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
+                    INDArray z = exceedsRelError(m.getModelParams(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
                     int count = z.sumNumber().intValue();
                     if (count > 0) {
-                        logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
+                        logFailedParams(20, "Parameter", layers, z, paramsExp, m.getModelParams());
                     }
                     assertEquals( 0, count, "Number of params exceeded max relative error");
                 } else {
@@ -607,12 +607,12 @@ public class IntegrationTestRunner {
                 ModelSerializer.writeModel(m, f, true);
                 MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
                 assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration());
-                assertEquals(mln.params(), restored.params());
+                assertEquals(mln.getModelParams(), restored.getModelParams());
             } else if(modelType == ModelType.CG){
                 ModelSerializer.writeModel(m, f, true);
                 ComputationGraph restored = ComputationGraph.load(f, true);
                 assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-                assertEquals(cg.params(), restored.params());
+                assertEquals(cg.getModelParams(), restored.getModelParams());
             } else {
                 sd.save(f, true);
                 SameDiff restored = SameDiff.load(f, true);
@@ -49,7 +49,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 
             assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -74,7 +74,7 @@ public class TestUtils {
             restored = ModelSerializer.restoreComputationGraph(bais, true);
 
             assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-            assertEquals(net.params(), restored.params());
+            assertEquals(net.getModelParams(), restored.getModelParams());
         } catch (IOException e){
             //Should never happen
             throw new RuntimeException(e);
@@ -26,7 +26,7 @@ import org.nd4j.common.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-/**
+/** The ActivationIdentity activation function, just returns the input as is.
  * f(x) = x
  */
 @EqualsAndHashCode(callSuper = false)
@@ -195,7 +195,7 @@ public abstract class BaseWorkspaceMgr<T extends Enum<T>> implements WorkspaceMg
     }
 
     @Override
-    public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
+    public INDArray validateArrayLocation(T arrayType, INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
         validateConfig(arrayType);
 
         if(scopeOutOfWs.contains(arrayType)){
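
Note: dropping the @NonNull annotations in the hunk above also drops the argument guards Lombok generates for them, so a null arrayType or array now reaches the method body instead of failing fast. A minimal illustration (hypothetical class; the generated check shown in the comment follows Lombok's documented convention):

    import lombok.NonNull;

    class NonNullDemo {
        // Lombok inserts: if (s == null) throw new NullPointerException("s is marked non-null but is null");
        static void withGuard(@NonNull String s) { }
        // No generated check: a null argument flows through unchecked.
        static void withoutGuard(String s) { }
    }
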
@@ -19,6 +19,7 @@ dependencies {
     testImplementation projects.cavisNative.cavisNativeCommon
     testImplementation projects.cavisNd4j.cavisNd4jCommonTests
     testImplementation projects.cavisDnn.cavisDnnCommonTests
+    testImplementation projects.cavisDnn.cavisDnnNn
 
     implementation "org.apache.commons:commons-lang3"
 
@@ -116,7 +116,7 @@ public class LayerHelperValidationUtil {
 
         MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net2With.init();
-        net2With.params().assign(netOrig.params());
+        net2With.getModelParams().assign(netOrig.getModelParams());
         log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses());
         removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
@@ -124,7 +124,7 @@ public class LayerHelperValidationUtil {
             Preconditions.checkNotNull(t.getFeatures(), "Features are not set (null)");
 
             for (boolean train : new boolean[]{false, true}) {
-                assertEquals(net1NoHelper.params(), net2With.params());
+                assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams());
                 String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
                 List<INDArray> ff1;
                 try {
@@ -180,7 +180,7 @@ public class LayerHelperValidationUtil {
                 double maxRE = relError.maxNumber().doubleValue();
                 log.info(s + "Output, max relative error: " + maxRE);
 
-                assertEquals(net1NoHelper.params(), net2With.params());  //Check that forward pass does not modify params
+                assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams());  //Check that forward pass does not modify params
                 assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
             }
         }
@@ -255,24 +255,24 @@ public class LayerHelperValidationUtil {
 
             net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
             net2With.init();
-            net2With.params().assign(netOrig.params());
+            net2With.getModelParams().assign(netOrig.getModelParams());
             log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
             removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
             CollectScoresListener listener = new CollectScoresListener(1);
-            net2With.setListeners(listener);
+            net2With.addTrainingListeners(listener);
             net2With.fit(t.getData());
 
             for( int i=0; i<2; i++ ) {
 
                 net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
                 net2With.init();
-                net2With.params().assign(netOrig.params());
+                net2With.getModelParams().assign(netOrig.getModelParams());
                 log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
                 removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 
                 CollectScoresListener listener2 = new CollectScoresListener(1);
-                net2With.setListeners(listener2);
+                net2With.addTrainingListeners(listener2);
                 net2With.fit(t.getData());
 
                 DoubleArrayList listOrig = listener.getListScore();
@@ -25,7 +25,7 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
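
Note: the import change above reflects a type rename in the configuration API: the configuration-side base class BaseLayer becomes BaseLayerConfiguration, and the TestUtils helper signatures in the hunks below are updated to match. A call-site sketch using names from this diff (usage assumed, mirroring the TestEarlyStopping hunk further down):

    BaseLayerConfiguration blc = (BaseLayerConfiguration) net.getLayerConfiguration();
    double l2 = TestUtils.getL2(blc);  // helper now accepts the renamed type
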
@ -67,7 +67,7 @@ public class TestUtils {
 | 
				
			|||||||
            restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 | 
					            restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
 | 
					            assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
 | 
				
			||||||
            assertEquals(net.params(), restored.params());
 | 
					            assertEquals(net.getModelParams(), restored.getModelParams());
 | 
				
			||||||
        } catch (IOException e){
 | 
					        } catch (IOException e){
 | 
				
			||||||
            //Should never happen
 | 
					            //Should never happen
 | 
				
			||||||
            throw new RuntimeException(e);
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
@ -91,7 +91,7 @@ public class TestUtils {
 | 
				
			|||||||
            restored = ModelSerializer.restoreComputationGraph(bais, true);
 | 
					            restored = ModelSerializer.restoreComputationGraph(bais, true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
 | 
					            assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
 | 
				
			||||||
            assertEquals(net.params(), restored.params());
 | 
					            assertEquals(net.getModelParams(), restored.getModelParams());
 | 
				
			||||||
        } catch (IOException e){
 | 
					        } catch (IOException e){
 | 
				
			||||||
            //Should never happen
 | 
					            //Should never happen
 | 
				
			||||||
            throw new RuntimeException(e);
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
@ -205,8 +205,8 @@ public class TestUtils {
 | 
				
			|||||||
        return null;
 | 
					        return null;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static L2Regularization getL2Reg(BaseLayer baseLayer){
 | 
					    public static L2Regularization getL2Reg(BaseLayerConfiguration baseLayerConfiguration){
 | 
				
			||||||
        return getL2Reg(baseLayer.getRegularization());
 | 
					        return getL2Reg(baseLayerConfiguration.getRegularization());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static L2Regularization getL2Reg(List<Regularization> l){
 | 
					    public static L2Regularization getL2Reg(List<Regularization> l){
 | 
				
			||||||
@ -218,7 +218,7 @@ public class TestUtils {
 | 
				
			|||||||
        return null;
 | 
					        return null;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static WeightDecay getWeightDecayReg(BaseLayer bl){
 | 
					    public static WeightDecay getWeightDecayReg(BaseLayerConfiguration bl){
 | 
				
			||||||
        return getWeightDecayReg(bl.getRegularization());
 | 
					        return getWeightDecayReg(bl.getRegularization());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -231,7 +231,7 @@ public class TestUtils {
 | 
				
			|||||||
        return null;
 | 
					        return null;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static double getL1(BaseLayer layer) {
 | 
					    public static double getL1(BaseLayerConfiguration layer) {
 | 
				
			||||||
        List<Regularization> l = layer.getRegularization();
 | 
					        List<Regularization> l = layer.getRegularization();
 | 
				
			||||||
        return getL1(l);
 | 
					        return getL1(l);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -246,7 +246,7 @@ public class TestUtils {
 | 
				
			|||||||
        return l1Reg.getL1().valueAt(0,0);
 | 
					        return l1Reg.getL1().valueAt(0,0);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static double getL2(BaseLayer layer) {
 | 
					    public static double getL2(BaseLayerConfiguration layer) {
 | 
				
			||||||
        List<Regularization> l = layer.getRegularization();
 | 
					        List<Regularization> l = layer.getRegularization();
 | 
				
			||||||
        return getL2(l);
 | 
					        return getL2(l);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -269,7 +269,7 @@ public class TestUtils {
 | 
				
			|||||||
        return getL2(layer.getRegularization());
 | 
					        return getL2(layer.getRegularization());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static double getWeightDecay(BaseLayer layer) {
 | 
					    public static double getWeightDecay(BaseLayerConfiguration layer) {
 | 
				
			||||||
        return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
 | 
					        return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -32,7 +32,6 @@ import org.deeplearning4j.eval.Evaluation;
 | 
				
			|||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 | 
					import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
 | 
					import org.deeplearning4j.nn.conf.GradientNormalization;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
					import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 | 
					 | 
				
			||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
					import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.OutputLayer;
 | 
				
			||||||
@ -183,7 +182,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
 | 
				
			|||||||
        MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
 | 
					        MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
 | 
				
			||||||
        model.init();
 | 
					        model.init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model.setListeners(new ScoreIterationListener(listenerFreq));
 | 
					        model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model.fit(lfw.next());
 | 
					        model.fit(lfw.next());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -247,7 +246,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
 | 
				
			|||||||
        //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
 | 
					        //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
 | 
					        CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
 | 
				
			||||||
        model.setListeners(listener);
 | 
					        model.addTrainingListeners(listener);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model.fit(cifar);
 | 
					        model.fit(cifar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -226,7 +226,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
					        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
				
			||||||
@ -255,7 +255,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
        MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
 | 
					        MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
 | 
				
			||||||
@ -304,7 +304,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
					        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
				
			||||||
@ -343,7 +343,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -386,7 +386,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -430,7 +430,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .build())
 | 
					                                        .build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
        int nSamples = 100;
 | 
					        int nSamples = 100;
 | 
				
			||||||
        //Generate the training data
 | 
					        //Generate the training data
 | 
				
			||||||
        INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
 | 
					        INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
 | 
				
			||||||
@ -473,7 +473,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
        MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
 | 
					        MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
 | 
				
			||||||
@ -496,9 +496,9 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        assertEquals(net.getnLayers(), mln.getnLayers());
 | 
					        assertEquals(net.getnLayers(), mln.getnLayers());
 | 
				
			||||||
        assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
 | 
					        assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
 | 
				
			||||||
        BaseLayer bl = (BaseLayer) net.getLayerConfiguration();
 | 
					        BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
 | 
				
			||||||
        assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.getLayerConfiguration()).getActivationFn().toString());
 | 
					        assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
 | 
				
			||||||
        assertEquals(bl.getIUpdater(), ((BaseLayer) mln.getLayerConfiguration()).getIUpdater());
 | 
					        assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
@ -511,7 +511,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
					                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 | 
				
			||||||
                        .build();
 | 
					                        .build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.setListeners(new ScoreIterationListener(1));
 | 
					        net.addTrainingListeners(new ScoreIterationListener(1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 | 
				
			||||||
        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
					        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
				
			||||||
@ -792,7 +792,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 | 
				
			|||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        TestListener tl = new TestListener();
 | 
					        TestListener tl = new TestListener();
 | 
				
			||||||
        net.setListeners(tl);
 | 
					        net.addTrainingListeners(tl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
 | 
					        DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
 | 
				
			||||||
        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
					        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
 | 
				
			||||||
 | 
				
			|||||||
@@ -84,7 +84,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -128,7 +128,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -165,7 +165,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

@@ -207,7 +207,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

@@ -241,7 +241,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
                                        .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                        .setOutputs("0").build();
        ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

        DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -538,7 +538,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
        ComputationGraph net = new ComputationGraph(conf);

        TestEarlyStopping.TestListener tl = new TestEarlyStopping.TestListener();
-        net.setListeners(tl);
+        net.addTrainingListeners(tl);

        DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
        EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -84,7 +84,7 @@ public class EvalTest extends BaseDL4JTest {
        // Instantiate model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
-        model.addListeners(new ScoreIterationListener(1));
+        model.addTrainingListeners(new ScoreIterationListener(1));

        // Train-test split
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
@@ -324,7 +324,7 @@ public class EvalTest extends BaseDL4JTest {
            MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
            net2.init();

-            net2.setParams(net1.params());
+            net2.setParams(net1.getModelParams());

            for(boolean useMask : new boolean[]{false, true}) {

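Parameter access is renamed from params() to getModelParams(); both are assumed to return the network's flattened parameter vector as an INDArray, so copying parameters between identically configured networks is unchanged apart from the call itself. A sketch under that assumption (helper names hypothetical):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

// Hypothetical helper: copies all parameters from one network to another
// with an identical configuration, using the renamed accessor.
class ParamCopy {
    static void copyParams(MultiLayerNetwork from, MultiLayerNetwork to) {
        // before this commit: to.setParams(from.params());
        to.setParams(from.getModelParams());
    }
}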
@@ -405,7 +405,7 @@ public class EvalTest extends BaseDL4JTest {
            ComputationGraph net2 = new ComputationGraph(conf2);
            net2.init();

-            net2.setParams(net1.params());
+            net2.setParams(net1.getModelParams());

            for (boolean useMask : new boolean[]{false, true}) {

@@ -492,7 +492,7 @@ public class EvalTest extends BaseDL4JTest {
        DataSetIterator iter = new IrisDataSetIterator(30, 150);
        DataSetIterator iterTest = new IrisDataSetIterator(30, 150);

-        net.setListeners(new EvaluativeListener(iterTest, 3));
+        net.addTrainingListeners(new EvaluativeListener(iterTest, 3));

        for( int i=0; i<3; i++ ){
            net.fit(iter);
@@ -26,7 +26,6 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
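With the nested NeuralNetConfiguration.NeuralNetConfigurationBuilder import dropped here (and in the other test classes below), configurations appear to be reached through the static NeuralNetConfiguration.builder() entry point, as the LayerConfigTest hunks further down show. A minimal sketch under that assumption (class and method names hypothetical; the builder is assumed to return a NeuralNetConfiguration as in those hunks):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

// Sketch: build a configuration via builder(), so the nested builder
// type never needs to be imported by name.
class BuilderSketch {
    static NeuralNetConfiguration twoDenseLayers() {
        return NeuralNetConfiguration.builder()
                .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
                .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
                .build();
    }
}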
@@ -219,11 +218,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                                mln.setInput(ds.getFeatures());
                                mln.setLabels(ds.getLabels());
                                mln.computeGradientAndScore();
-                                double scoreBefore = mln.score();
+                                double scoreBefore = mln.getScore();
                                for (int k = 0; k < 20; k++)
                                    mln.fit(ds);
                                mln.computeGradientAndScore();
-                                double scoreAfter = mln.score();
+                                double scoreAfter = mln.getScore();
                                //Can't test in 'characteristic mode of operation' if not learning
                                String msg = name
                                        + " - score did not (sufficiently) decrease during learning - activationFn="
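Score access is renamed from score() to getScore() on both MultiLayerNetwork and ComputationGraph; as in the tests, the value is only meaningful after computeGradientAndScore() has run. A sketch of the before/after measurement these gradient checks repeat (helper names hypothetical):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

// Hypothetical helper mirroring the test pattern: score before and after
// a few fit() calls, via the renamed getScore() accessor.
class ScoreProbe {
    static double[] scoreBeforeAfter(MultiLayerNetwork mln, DataSet ds, int iters) {
        mln.setInput(ds.getFeatures());
        mln.setLabels(ds.getLabels());
        mln.computeGradientAndScore();
        double before = mln.getScore();
        for (int k = 0; k < iters; k++)
            mln.fit(ds);
        mln.computeGradientAndScore();
        double after = mln.getScore();
        return new double[]{before, after};
    }
}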
@@ -323,11 +322,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                                mln.setInput(ds.getFeatures());
                                mln.setLabels(ds.getLabels());
                                mln.computeGradientAndScore();
-                                double scoreBefore = mln.score();
+                                double scoreBefore = mln.getScore();
                                for (int k = 0; k < 10; k++)
                                    mln.fit(ds);
                                mln.computeGradientAndScore();
-                                double scoreAfter = mln.score();
+                                double scoreAfter = mln.getScore();
                                //Can't test in 'characteristic mode of operation' if not learning
                                String msg = name
                                        + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -554,11 +553,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                            net.setInput(0, ds.getFeatures());
                            net.setLabels(ds.getLabels());
                            net.computeGradientAndScore();
-                            double scoreBefore = net.score();
+                            double scoreBefore = net.getScore();
                            for (int k = 0; k < 20; k++)
                                net.fit(ds);
                            net.computeGradientAndScore();
-                            double scoreAfter = net.score();
+                            double scoreAfter = net.getScore();
                            //Can't test in 'characteristic mode of operation' if not learning
                            String msg = name
                                    + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -27,7 +27,6 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
@@ -120,11 +119,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                        mln.setInput(ds.getFeatures());
                        mln.setLabels(ds.getLabels());
                        mln.computeGradientAndScore();
-                        double scoreBefore = mln.score();
+                        double scoreBefore = mln.getScore();
                        for (int j = 0; j < 10; j++)
                            mln.fit(ds);
                        mln.computeGradientAndScore();
-                        double scoreAfter = mln.score();
+                        double scoreAfter = mln.getScore();
                        //Can't test in 'characteristic mode of operation' if not learning
                        String msg = name + " - score did not (sufficiently) decrease during learning - activationFn="
                                + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -212,11 +211,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                mln.setInput(ds.getFeatures());
                mln.setLabels(ds.getLabels());
                mln.computeGradientAndScore();
-                double scoreBefore = mln.score();
+                double scoreBefore = mln.getScore();
                for (int j = 0; j < 10; j++)
                    mln.fit(ds);
                mln.computeGradientAndScore();
-                double scoreAfter = mln.score();
+                double scoreAfter = mln.getScore();
                //Can't test in 'characteristic mode of operation' if not learning
                String msg = testName
                        + "- score did not (sufficiently) decrease during learning - activationFn="
@@ -105,11 +105,11 @@ public class GradientCheckTests extends BaseDL4JTest {
            mln.setInput(ds.getFeatures());
            mln.setLabels(ds.getLabels());
            mln.computeGradientAndScore();
-            double scoreBefore = mln.score();
+            double scoreBefore = mln.getScore();
            for (int j = 0; j < 10; j++)
                mln.fit(ds);
            mln.computeGradientAndScore();
-            double scoreAfter = mln.score();
+            double scoreAfter = mln.getScore();
            //Can't test in 'characteristic mode of operation' if not learning
            String msg = "testMinibatchApplication() - score did not (sufficiently) decrease during learning - activationFn="
                    + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -184,11 +184,11 @@ public class GradientCheckTests extends BaseDL4JTest {
                        mln.setInput(ds.getFeatures());
                        mln.setLabels(ds.getLabels());
                        mln.computeGradientAndScore();
-                        double scoreBefore = mln.score();
+                        double scoreBefore = mln.getScore();
                        for (int j = 0; j < 10; j++)
                            mln.fit(ds);
                        mln.computeGradientAndScore();
-                        double scoreAfter = mln.score();
+                        double scoreAfter = mln.getScore();
                        //Can't test in 'characteristic mode of operation' if not learning
                        String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
                                        + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -278,11 +278,11 @@ public class GradientCheckTests extends BaseDL4JTest {
                            mln.setInput(ds.getFeatures());
                            mln.setLabels(ds.getLabels());
                            mln.computeGradientAndScore();
-                            double scoreBefore = mln.score();
+                            double scoreBefore = mln.getScore();
                            for (int j = 0; j < 10; j++)
                                mln.fit(ds);
                            mln.computeGradientAndScore();
-                            double scoreAfter = mln.score();
+                            double scoreAfter = mln.getScore();
                            //Can't test in 'characteristic mode of operation' if not learning
                            String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
                                            + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -452,11 +452,11 @@ public class GradientCheckTests extends BaseDL4JTest {
                            mln.setInput(ds.getFeatures());
                            mln.setLabels(ds.getLabels());
                            mln.computeGradientAndScore();
-                            double scoreBefore = mln.score();
+                            double scoreBefore = mln.getScore();
                            for (int j = 0; j < 10; j++)
                                mln.fit(ds);
                            mln.computeGradientAndScore();
-                            double scoreAfter = mln.score();
+                            double scoreAfter = mln.getScore();
                            //Can't test in 'characteristic mode of operation' if not learning
                            msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
                                            + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -523,13 +523,13 @@ public class GradientCheckTests extends BaseDL4JTest {
            netGraph.setInputs(features);
            netGraph.setLabels(labels);
            netGraph.computeGradientAndScore();
-            double scoreBefore = netGraph.score();
+            double scoreBefore = netGraph.getScore();

            String msg;
            for (int epoch = 0; epoch < 5; epoch++)
                netGraph.fit(new INDArray[]{features}, new INDArray[]{labels});
            netGraph.computeGradientAndScore();
-            double scoreAfter = netGraph.score();
+            double scoreAfter = netGraph.getScore();
            //Can't test in 'characteristic mode of operation' if not learning
            msg = "elementWiseMultiplicationLayerTest() - score did not (sufficiently) decrease during learning - activationFn="
                    + "Id" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id"
@@ -757,11 +757,11 @@ public class GradientCheckTests extends BaseDL4JTest {
                            mln.setInput(ds.getFeatures());
                            mln.setLabels(ds.getLabels());
                            mln.computeGradientAndScore();
-                            double scoreBefore = mln.score();
+                            double scoreBefore = mln.getScore();
                            for (int j = 0; j < 10; j++)
                                mln.fit(ds);
                            mln.computeGradientAndScore();
-                            double scoreAfter = mln.score();
+                            double scoreAfter = mln.getScore();
                            //Can't test in 'characteristic mode of operation' if not learning
                            String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
                                    + afn + ", lossFn=" + lf + ", layerNorm=" + layerNorm + ", outputActivation=" + outputActivation
@@ -666,7 +666,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                    net.init();

                    //Check params to avoid test flakiness on small or large params
-                    INDArray params = net.params();
+                    INDArray params = net.getModelParams();
                    for( int x=0; x<params.length(); x++ ){
                        while(Math.abs(params.getDouble(x)) < 0.01 || Math.abs(params.getDouble(x)) > 1.5){
                            double d = Nd4j.getRandom().nextDouble();
@@ -37,10 +37,9 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
@@ -254,8 +253,8 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
    MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
    model2.init();

-    float[] p1 = model1.params().data().asFloat();
-    float[] p2 = model2.params().data().asFloat();
+    float[] p1 = model1.getModelParams().data().asFloat();
+    float[] p2 = model2.getModelParams().data().asFloat();
    System.out.println(Arrays.toString(p1));
    System.out.println(Arrays.toString(p2));

@@ -266,20 +265,20 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
  public void testTrainingListener() {
    MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
    model1.init();
-    model1.addListeners(new ScoreIterationListener(1));
+    model1.addTrainingListeners(new ScoreIterationListener(1));

    MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
-    model2.addListeners(new ScoreIterationListener(1));
+    model2.addTrainingListeners(new ScoreIterationListener(1));
    model2.init();

    Layer[] l1 = model1.getLayers();
      for (int i = 0; i < l1.length; i++) {
-          assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1);
+          assertTrue(l1[i].getTrainingListeners() != null && l1[i].getTrainingListeners().size() == 1);
      }

    Layer[] l2 = model2.getLayers();
      for (int i = 0; i < l2.length; i++) {
-          assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1);
+          assertTrue(l2[i].getTrainingListeners() != null && l2[i].getTrainingListeners().size() == 1);
      }
  }

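The read side of the listener API is renamed too: Layer.getListeners() becomes getTrainingListeners(). A sketch of the per-layer check this test performs, assuming listeners were attached via addTrainingListeners(...) (helper names hypothetical):

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

// Hypothetical helper: verifies every layer reports exactly one
// listener through the renamed getTrainingListeners() accessor.
class ListenerCheck {
    static boolean eachLayerHasOneListener(MultiLayerNetwork model) {
        for (Layer l : model.getLayers()) {
            if (l.getTrainingListeners() == null || l.getTrainingListeners().size() != 1)
                return false;
        }
        return true;
    }
}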
@@ -384,10 +383,10 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
            .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
        .inputType(InputType.convolutional(28, 28, 1)).build();

-    org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer();
-    org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer();
-    org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer();
-    org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer();
+    BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
+    BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
+    BaseLayerConfiguration l2 = (BaseLayerConfiguration) conf.getConf(2).getLayer();
+    BaseLayerConfiguration l3 = (BaseLayerConfiguration) conf.getConf(3).getLayer();

    assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
    assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);
@@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -100,7 +100,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
    @Test
    public void testClone() {
        NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
-        BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0);
+        BaseLayerConfiguration bl = (BaseLayerConfiguration) conf.getFlattenedLayerConfigurations().get(0);
        conf.setStepFunction(new DefaultStepFunction());

        NeuralNetConfiguration conf2 = conf.clone();
@@ -158,7 +158,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
        cg.setInput(0, input);
        cg.setLabel(0, target);
        cg.computeGradientAndScore();
-        double score_dl4j = cg.score();
+        double score_dl4j = cg.getScore();
        Map<String, INDArray> weights = cg.getParamTable();
        Gradient g = cg.gradient();
        Map<String, INDArray> gradients = g.gradientForVariable();
@@ -72,8 +72,8 @@ public class LayerConfigTest extends BaseDL4JTest {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
-        assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
+        assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
+        assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());

        //With
        conf = NeuralNetConfiguration.builder().activation(Activation.RELU)
@@ -83,8 +83,8 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
-        assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
+        assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
+        assertEquals("tanh", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());
    }


@@ -99,11 +99,11 @@ public class LayerConfigTest extends BaseDL4JTest {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
-        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
+        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
+        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());

-        assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
-        assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
+        assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
+        assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);

        //With:
        final Distribution overriddenDistribution = new UniformDistribution(0, 1);
@@ -117,11 +117,11 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
-        assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
+        assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
+        assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());

-        assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
-        assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
+        assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
+        assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
    }

    /*
@@ -137,8 +137,8 @@ public class LayerConfigTest extends BaseDL4JTest {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
-        assertEquals(0.3, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
+        assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
+        assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);

        //With:
        conf = NeuralNetConfiguration.builder().learningRate(0.3)
@@ -148,8 +148,8 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
-        assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
+        assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
+        assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);

        //L1 and L2 without layerwise override:
        conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)
@@ -158,10 +158,10 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(0.1, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0);
-        assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0);
-        assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
-        assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
+        assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
+        assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
+        assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
+        assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);

        //L1 and L2 with layerwise override:
        conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)
@@ -170,10 +170,10 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(0.9, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0);
-        assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0);
-        assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
-        assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
+        assertEquals(0.9, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
+        assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
+        assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
+        assertEquals(0.8, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);
    }*/


@@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-        assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-        assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+        assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+        assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);

        Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
        testMomentumAfter2.put(0, 0.2);
@@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest {

        net = new MultiLayerNetwork(conf);
        net.init();
-        assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
-        assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+        assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
+        assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
    }

    @Test
@@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-        assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
-        assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-        assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
-        assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+        assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
+        assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
+        assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
+        assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);

        conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
                        .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
@@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest {
        net = new MultiLayerNetwork(conf);
        net.init();

-        assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
-        assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
-        assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
-        assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
+        assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
+        assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
+        assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
+        assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
    }


@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 | 
				
			|||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.init();
 | 
					        net.init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
 | 
					        assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
 | 
				
			||||||
        assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
 | 
					        assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
 | 
				
			||||||
        assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
 | 
					        assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
 | 
				
			||||||
        assertEquals(0.7, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
 | 
					        assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
@ -287,13 +287,11 @@ public class LayerConfigTest extends BaseDL4JTest {
 | 
				
			|||||||
                        .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
 | 
					                        .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
 | 
				
			||||||
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
					        MultiLayerNetwork net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.init();
 | 
					        net.init();
 | 
				
			||||||
 | 
					        BaseLayerConfiguration bconf = (BaseLayerConfiguration)  conf.getConf(0).getLayer();
 | 
				
			||||||
        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
 | 
					        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
 | 
				
			||||||
                        conf.getConf(0).getLayer().getGradientNormalization());
 | 
					        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
 | 
				
			||||||
        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
 | 
					        assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
 | 
				
			||||||
                        conf.getConf(1).getLayer().getGradientNormalization());
 | 
					        assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
 | 
				
			||||||
        assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
 | 
					 | 
				
			||||||
        assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        //With:
 | 
					        //With:
 | 
				
			||||||
        conf = NeuralNetConfiguration.builder()
 | 
					        conf = NeuralNetConfiguration.builder()
 | 
				
			||||||
@ -308,11 +306,10 @@ public class LayerConfigTest extends BaseDL4JTest {
 | 
				
			|||||||
        net = new MultiLayerNetwork(conf);
 | 
					        net = new MultiLayerNetwork(conf);
 | 
				
			||||||
        net.init();
 | 
					        net.init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
 | 
					        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
 | 
				
			||||||
                        conf.getConf(0).getLayer().getGradientNormalization());
 | 
					        assertEquals(GradientNormalization.None, bconf.getGradientNormalization());
 | 
				
			||||||
        assertEquals(GradientNormalization.None, conf.getConf(1).getLayer().getGradientNormalization());
 | 
					        assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
 | 
				
			||||||
        assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
 | 
					        assertEquals(2.5, bconf.getGradientNormalizationThreshold(), 0.0);
 | 
				
			||||||
        assertEquals(2.5, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
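The hunks above are part of the commit-wide rename of the layer-configuration base class from BaseLayer to BaseLayerConfiguration. A minimal sketch of the resulting accessor pattern, assuming a two-layer NeuralNetConfiguration built as in the test (the class and method names of the sketch itself are illustrative, not part of the commit):

// Sketch only: reads the layer-0 momentum back out of a built configuration,
// using the calls that appear in the hunks above.
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.nd4j.linalg.learning.config.Nesterovs;

public class UpdaterAccessSketch {
    static double momentumOfLayer(NeuralNetConfiguration conf, int layerIdx) {
        // Layer configurations are now cast to BaseLayerConfiguration, not BaseLayer:
        BaseLayerConfiguration blc = (BaseLayerConfiguration) conf.getConf(layerIdx).getLayer();
        // The per-layer updater still exposes its schedule through getIUpdater():
        return ((Nesterovs) blc.getIUpdater()).getMomentumISchedule().valueAt(0, 0);
    }
}

Casting once and reusing the reference, as the later bconf hunks do, avoids repeating the cast in every assertion.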
@@ -162,12 +162,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        BaseLayer layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+        BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
         assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
         assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
         assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);

-        BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+        BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
         assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);

         // Adam Updater
@@ -178,11 +178,11 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
         net = new MultiLayerNetwork(conf);
         net.init();

-        layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+        layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
         assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3);
         assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);

-        layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+        layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
         assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
         assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
         assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
@@ -196,12 +196,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
         net = new MultiLayerNetwork(conf);
         net.init();

-        layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+        layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
         assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
         assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
         assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));

-        layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+        layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
         assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);

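Unlike the LayerConfigTest hunks, these validation tests read the configuration back through the runtime layer rather than the builder output. A minimal sketch of that path, assuming an initialized MultiLayerNetwork whose first layer uses a Nesterovs updater (sketch names are illustrative):

// Sketch only: fetch a layer's config from the live network and read a hyperparameter.
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.learning.config.Nesterovs;

class LayerHyperparamSketch {
    static double momentum(MultiLayerNetwork net, int layerIdx) {
        // getLayerConfiguration() returns the config object attached to the runtime
        // layer; after the rename it is cast to BaseLayerConfiguration.
        BaseLayerConfiguration lc =
                (BaseLayerConfiguration) net.getLayer(layerIdx).getLayerConfiguration();
        return ((Nesterovs) lc.getIUpdater()).getMomentum();
    }
}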
@@ -29,7 +29,7 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -75,9 +75,9 @@ public class TestWeightNoise extends BaseDL4JTest {
             MultiLayerNetwork net = new MultiLayerNetwork(conf);
             net.init();

-            assertEquals(wn, ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
-            assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
-            assertEquals(wn, ((BaseLayer) net.getLayer(2).getLayerConfiguration()).getWeightNoise());
+            assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
+            assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
+            assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(2).getLayerConfiguration()).getWeightNoise());

             TestUtils.testModelSerialization(net);

@@ -95,9 +95,9 @@ public class TestWeightNoise extends BaseDL4JTest {
             ComputationGraph graph = new ComputationGraph(conf2);
             graph.init();

-            assertEquals(wn, ((BaseLayer) graph.getLayer(0).getLayerConfiguration()).getWeightNoise());
-            assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).getLayerConfiguration()).getWeightNoise());
-            assertEquals(wn, ((BaseLayer) graph.getLayer(2).getLayerConfiguration()).getWeightNoise());
+            assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(0).getLayerConfiguration()).getWeightNoise());
+            assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) graph.getLayer(1).getLayerConfiguration()).getWeightNoise());
+            assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(2).getLayerConfiguration()).getWeightNoise());

             TestUtils.testModelSerialization(graph);

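The weight-noise assertions above apply the same cast change when checking which noise instance each layer resolved to. A minimal sketch of that check, assuming an initialized MultiLayerNetwork with a network-level WeightNoise default and DropConnect(0.25) set on one layer, as in the test (sketch names are illustrative):

// Sketch only: inspect the effective weight noise of a given layer.
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class WeightNoiseCheckSketch {
    static boolean layerUsesDropConnect(MultiLayerNetwork net, int layerIdx) {
        BaseLayerConfiguration lc =
                (BaseLayerConfiguration) net.getLayer(layerIdx).getLayerConfiguration();
        // getWeightNoise() reports the noise attached to this layer's configuration,
        // whether set per-layer or inherited from the network-level default.
        return lc.getWeightNoise() instanceof DropConnect;
    }
}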
@@ -124,7 +124,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
 import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
-import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
 import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
 import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
 import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
@@ -260,8 +260,8 @@ public class DTypeTests extends BaseDL4JTest {
         for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) {
             LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0);
             seenLayers.add(l.getClass());
-            if (l instanceof BaseWrapperLayer) {
-                BaseWrapperLayer bwl = (BaseWrapperLayer) l;
+            if (l instanceof BaseWrapperLayerConfiguration) {
+                BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration) l;
                 seenLayers.add(bwl.getUnderlying().getClass());
             } else if (l instanceof Bidirectional) {
                 seenLayers.add(((Bidirectional) l).getFwd().getClass());
@@ -321,17 +321,17 @@ public class DTypeTests extends BaseDL4JTest {
             net.setInput(inD);
             net.setLabels(lD);
             net.computeGradientAndScore();
-            double scoreDouble = net.score();
+            double scoreDouble = net.getScore();
             INDArray grads = net.getFlattenedGradients();
             INDArray u = net.getUpdater().getStateViewArray();
-            assertEquals(DataType.DOUBLE, net.params().dataType());
+            assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
             assertEquals(DataType.DOUBLE, grads.dataType());
             assertEquals(DataType.DOUBLE, u.dataType());


             MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT);
             netFloat.initGradientsView();
-            assertEquals(DataType.FLOAT, netFloat.params().dataType());
+            assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
             assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
             assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
             INDArray inF = inD.castTo(DataType.FLOAT);
@@ -340,7 +340,7 @@ public class DTypeTests extends BaseDL4JTest {
             netFloat.setInput(inF);
             netFloat.setLabels(lF);
             netFloat.computeGradientAndScore();
-            double scoreFloat = netFloat.score();
+            double scoreFloat = netFloat.getScore();
             INDArray gradsFloat = netFloat.getFlattenedGradients();
             INDArray uFloat = netFloat.getUpdater().getStateViewArray();

@@ -352,7 +352,7 @@ public class DTypeTests extends BaseDL4JTest {

             MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF);
             netFP16.initGradientsView();
-            assertEquals(DataType.HALF, netFP16.params().dataType());
+            assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
             assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
             assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());

@@ -362,7 +362,7 @@ public class DTypeTests extends BaseDL4JTest {
             netFP16.setInput(inH);
             netFP16.setLabels(lH);
             netFP16.computeGradientAndScore();
-            double scoreHalf = netFP16.score();
+            double scoreHalf = netFP16.getScore();
             INDArray gradsHalf = netFP16.getFlattenedGradients();
             INDArray uHalf = netFP16.getUpdater().getStateViewArray();

@@ -406,17 +406,17 @@ public class DTypeTests extends BaseDL4JTest {
             net.setInput(0, inD);
             net.setLabels(lD);
             net.computeGradientAndScore();
-            double scoreDouble = net.score();
+            double scoreDouble = net.getScore();
             INDArray grads = net.getFlattenedGradients();
             INDArray u = net.getUpdater().getStateViewArray();
-            assertEquals(DataType.DOUBLE, net.params().dataType());
+            assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
             assertEquals(DataType.DOUBLE, grads.dataType());
             assertEquals(DataType.DOUBLE, u.dataType());


             ComputationGraph netFloat = net.convertDataType(DataType.FLOAT);
             netFloat.initGradientsView();
-            assertEquals(DataType.FLOAT, netFloat.params().dataType());
+            assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
             assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
             assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
             INDArray inF = inD.castTo(DataType.FLOAT);
@@ -425,7 +425,7 @@ public class DTypeTests extends BaseDL4JTest {
             netFloat.setInput(0, inF);
             netFloat.setLabels(lF);
             netFloat.computeGradientAndScore();
-            double scoreFloat = netFloat.score();
+            double scoreFloat = netFloat.getScore();
             INDArray gradsFloat = netFloat.getFlattenedGradients();
             INDArray uFloat = netFloat.getUpdater().getStateViewArray();

@@ -437,7 +437,7 @@ public class DTypeTests extends BaseDL4JTest {

             ComputationGraph netFP16 = net.convertDataType(DataType.HALF);
             netFP16.initGradientsView();
-            assertEquals(DataType.HALF, netFP16.params().dataType());
+            assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
             assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
             assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());

@@ -447,7 +447,7 @@ public class DTypeTests extends BaseDL4JTest {
             netFP16.setInput(0, inH);
             netFP16.setLabels(lH);
             netFP16.computeGradientAndScore();
-            double scoreHalf = netFP16.score();
+            double scoreHalf = netFP16.getScore();
             INDArray gradsHalf = netFP16.getFlattenedGradients();
             INDArray uHalf = netFP16.getUpdater().getStateViewArray();

@@ -536,7 +536,7 @@ public class DTypeTests extends BaseDL4JTest {
                     net.init();

                     net.initGradientsView();
-                    assertEquals(networkDtype, net.params().dataType(), msg);
+                    assertEquals(networkDtype, net.getModelParams().dataType(), msg);
                     assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
                     assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -641,7 +641,7 @@ public class DTypeTests extends BaseDL4JTest {
                     net.init();

                     net.initGradientsView();
-                    assertEquals(networkDtype, net.params().dataType(), msg);
+                    assertEquals(networkDtype, net.getModelParams().dataType(), msg);
                     assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
                     assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -754,7 +754,7 @@ public class DTypeTests extends BaseDL4JTest {
                     net.init();

                     net.initGradientsView();
-                    assertEquals(networkDtype, net.params().dataType(), msg);
+                    assertEquals(networkDtype, net.getModelParams().dataType(), msg);
                     assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
                     assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -827,7 +827,7 @@ public class DTypeTests extends BaseDL4JTest {
                net.init();

                net.initGradientsView();
-                assertEquals(networkDtype, net.params().dataType(), msg);
+                assertEquals(networkDtype, net.getModelParams().dataType(), msg);
                assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
                assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

@@ -916,7 +916,7 @@ public class DTypeTests extends BaseDL4JTest {
                     net.init();

                     net.initGradientsView();
-                    assertEquals(networkDtype, net.params().dataType(), msg);
+                    assertEquals(networkDtype, net.getModelParams().dataType(), msg);
                     assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
                     assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);

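Two accessor renames run through every DTypeTests hunk above: net.params() becomes net.getModelParams() and net.score() becomes net.getScore(). A minimal sketch of the renamed calls together, assuming an initialized MultiLayerNetwork with input and labels already set, as in the tests (sketch names are illustrative):

// Sketch only: score the network and verify the dtype its parameters are stored in.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

class DtypeCheckSketch {
    static void checkDtype(MultiLayerNetwork net, DataType expected) {
        net.computeGradientAndScore();
        double score = net.getScore();            // was: net.score()
        INDArray params = net.getModelParams();   // was: net.params()
        if (params.dataType() != expected) {
            throw new IllegalStateException("params stored as " + params.dataType()
                    + ", expected " + expected + " (score=" + score + ")");
        }
    }
}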
@@ -520,9 +520,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
         INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
         INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);

-        INDArray initialParams = graph.params().dup();
+        INDArray initialParams = graph.getModelParams().dup();
         graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong});
-        INDArray afterParams = graph.params();
+        INDArray afterParams = graph.getModelParams();

         assertNotEquals(initialParams, afterParams);
     }

@@ -117,7 +117,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
         boolean orderOK = Arrays.equals(expOrder1, order) || Arrays.equals(expOrder2, order);
         assertTrue(orderOK);

-        INDArray params = graph.params();
+        INDArray params = graph.getModelParams();
         assertNotNull(params);

         // confirm param shape is what is expected
@@ -129,7 +129,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {

         // params are set
         graph.setParams(arr);
-        params = graph.params();
+        params = graph.getModelParams();
         assertEquals(arr, params);

         //Number of inputs and outputs:

@@ -108,7 +108,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
                 }
             }

-            int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
+            int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
             assertEquals(0, count);


@@ -125,7 +125,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
                 }
             }

-            count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0);
+            count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
             assertEquals(0, count);
         }
     }
@@ -176,7 +176,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
             Nd4j.getRandom().setSeed(12345);
             cg.pretrainLayer("0", ds);

-            assertEquals(net.params(), cg.params());
+            assertEquals(net.getModelParams(), cg.getModelParams());
         }
     }

@@ -159,7 +159,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         DataSet ds = iris.next();

         graph.setInput(0, ds.getFeatures());
-        net.setParams(graph.params());
+        net.setParams(graph.getModelParams());
         Map<String, INDArray> activations = graph.feedForward(false);

         List<INDArray> feedForward = net.feedForward(ds.getFeatures());
@@ -184,7 +184,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         int[] expOrder = new int[]{0, 1, 2};
         assertArrayEquals(expOrder, order); //Only one valid order: 0 (input) -> 1 (firstlayer) -> 2 (outputlayer)

-        INDArray params = graph.params();
+        INDArray params = graph.getModelParams();
         assertNotNull(params);

         int nParams = getNumParams();
@@ -194,7 +194,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         assertEquals(nParams, arr.length());

         graph.setParams(arr);
-        params = graph.params();
+        params = graph.getModelParams();
         assertEquals(arr, params);

         //Number of inputs and outputs:
@@ -315,8 +315,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         graph.fit(iris);

         //Check that parameters are equal for both models after fitting:
-        INDArray paramsMLN = net.params();
-        INDArray paramsGraph = graph.params();
+        INDArray paramsMLN = net.getModelParams();
+        INDArray paramsGraph = graph.getModelParams();

         assertNotEquals(params, paramsGraph);
         assertEquals(paramsMLN, paramsGraph);
@@ -636,7 +636,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

         ComputationGraph net = new ComputationGraph(conf);
         net.init();
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));

         DataSetIterator iter = new IrisDataSetIterator(10, 150);
         net.pretrain(iter);
@@ -675,7 +675,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

         ComputationGraph netNoReg = new ComputationGraph(confNoReg);
         netNoReg.init();
-        netNoReg.setParams(net.params().dup());
+        netNoReg.setParams(net.getModelParams().dup());

         //Score single example, and compare to scoreExamples:
         INDArray input = Nd4j.rand(3, nIn);
@@ -878,13 +878,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         net.setParam("first_b", Nd4j.ones(1, 5));
         net.setParam("output_W", Nd4j.ones(5, 3));
         net.setParam("output_b", Nd4j.ones(1, 3));
-        INDArray actualParams = net.params();
+        INDArray actualParams = net.getModelParams();

         // Confirm params
         assertEquals(Nd4j.ones(1, 43), actualParams);

         net.update(expectedGradient);
-        actualParams = net.params();
+        actualParams = net.getModelParams();
         assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
     }

@@ -1638,7 +1638,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         conf3.setTopologicalOrderStr(null);
         ComputationGraph cg3 = new ComputationGraph(conf3);
         cg3.init();
-        cg3.setParams(cg2.params());
+        cg3.setParams(cg2.getModelParams());

         int[] order3 = cg3.topologicalSortOrder();
         List<String> strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr();
@@ -1712,7 +1712,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         exp.add(ComputationGraph.class);

         MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener();
-        net.setListeners(listener);
+        net.addTrainingListeners(listener);

         INDArray f = Nd4j.create(1,10);
         INDArray l = Nd4j.create(1,10);
@@ -1874,7 +1874,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

         ComputationGraph cg = new ComputationGraph(conf);
         cg.init();
-        cg.params().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11));
+        cg.getModelParams().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11));

         INDArray p0w = cg.getParam("layer_zero_W");
         assertEquals(Nd4j.linspace(1, 100, 100).reshape('f', 10, 10), p0w);

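Besides the getModelParams() rename, the hunks above also migrate the listener API: setListeners(...) is replaced by addTrainingListeners(...). A minimal sketch of the updated wiring, assuming a built ComputationGraphConfiguration (sketch names are illustrative; the add* name suggests listeners are appended to the existing list, though this commit does not show that directly):

// Sketch only: attach a score logger using the renamed listener call.
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

class ListenerSketch {
    static ComputationGraph withScoreLogging(ComputationGraphConfiguration conf) {
        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        // was: net.setListeners(new ScoreIterationListener(1));
        net.addTrainingListeners(new ScoreIterationListener(1));
        return net;
    }
}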
@@ -56,7 +56,7 @@ public class TestSetGetParameters extends BaseDL4JTest {

         ComputationGraph net = new ComputationGraph(conf);
         net.init();
-        INDArray params = net.params();
+        INDArray params = net.getModelParams();


         ComputationGraph net2 = new ComputationGraph(conf);
@@ -65,11 +65,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
         ComputationGraph net3 = new ComputationGraph(conf);
         net3.init(params, false);

-        assertEquals(params, net2.params());
-        assertEquals(params, net3.params());
+        assertEquals(params, net2.getModelParams());
+        assertEquals(params, net3.getModelParams());

-        assertNotSame(params, net2.params()); //Different objects due to clone
-        assertSame(params, net3.params()); //Same object due to clone
+        assertNotSame(params, net2.getModelParams()); //Different objects due to clone
+        assertSame(params, net3.getModelParams()); //Same object due to clone


         Map<String, INDArray> paramsMap = net.getParamTable();

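The assertSame/assertNotSame pair above pins down the cloning behavior of init(params, cloneParametersArray): with false, the network keeps a reference to the supplied array rather than copying it. A minimal sketch of that sharing, assuming a built ComputationGraphConfiguration (sketch names are illustrative):

// Sketch only: initialize a graph that shares, rather than clones, a parameter array.
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

class ParamSharingSketch {
    static ComputationGraph sharingParams(ComputationGraphConfiguration conf, INDArray params) {
        ComputationGraph shared = new ComputationGraph(conf);
        // cloneParametersArray = false: getModelParams() returns the very same object,
        // so external writes to "params" are visible inside the network.
        shared.init(params, false);
        assert shared.getModelParams() == params;
        return shared;
    }
}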
@@ -103,14 +103,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
             net.setInput(0, in1);
             net.setLabel(0, labels1);
             net.computeGradientAndScore();
-            double score1 = net.score();
+            double score1 = net.getScore();
             Gradient g1 = net.gradient();

             net.setInput(0, in2);
             net.setLabel(0, labels2);
             net.setLayerMaskArrays(null, new INDArray[] {labelMask});
             net.computeGradientAndScore();
-            double score2 = net.score();
+            double score2 = net.getScore();
             Gradient g2 = net.gradient();

             //Scores and gradients should be identical for two cases (given mask array)
@@ -134,7 +134,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
                 }
                 net.setLabel(0, labels2);
                 net.computeGradientAndScore();
-                double score2a = net.score();
+                double score2a = net.getScore();
                 Gradient g2a = net.gradient();
                 assertEquals(score2, score2a, 1e-6);
                 for (String s : g2map.keySet()) {
@@ -200,7 +200,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
             net.setInput(0, in1);
             net.setLabel(0, labels1);
             net.computeGradientAndScore();
-            double score1 = net.score();
+            double score1 = net.getScore();
             Gradient g1 = net.gradient();
             Map<String, INDArray> map = g1.gradientForVariable();
             for (String s : map.keySet()) {
@@ -211,7 +211,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
             net.setLabel(0, labels2);
             net.setLayerMaskArrays(new INDArray[] {inputMask}, null);
             net.computeGradientAndScore();
-            double score2 = net.score();
+            double score2 = net.getScore();
             Gradient g2 = net.gradient();
             Map<String, INDArray> activations2 = net.feedForward();

@@ -236,7 +236,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
                 net.setInput(0, in2);
                 net.setLayerMaskArrays(new INDArray[]{inputMask}, null);
                 net.computeGradientAndScore();
-                double score2a = net.score();
+                double score2a = net.getScore();
                 Gradient g2a = net.gradient();
                 assertEquals(score2, score2a, 1e-12);
                 for (String s : g2.gradientForVariable().keySet()) {
@@ -330,7 +330,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
                        net.setLabel(0, labels);

                        net.computeGradientAndScore();
-                        double score = net.score();
+                        double score = net.getScore();

                        assertEquals(expScore, score, 0.1, msg);
                    }

@@ -40,7 +40,7 @@ import java.util.Map;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;

-public class BaseLayerTest extends BaseDL4JTest {
+public class BaseLayerConfigurationTest extends BaseDL4JTest {

     protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2});
     protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2});
@@ -56,10 +56,10 @@ public class CacheModeTest extends BaseDL4JTest {
         INDArray out2 = net2.output(in);
         assertEquals(out1, out2);
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
         net1.fit(in, labels);
         net2.fit(in, labels);
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
     }
 
     private static NeuralNetConfiguration getConf(CacheMode cacheMode){
@@ -99,10 +99,10 @@ public class CacheModeTest extends BaseDL4JTest {
             INDArray out2 = net2.output(in);
             assertEquals(out1, out2);
 
-            assertEquals(net1.params(), net2.params());
+            assertEquals(net1.getModelParams(), net2.getModelParams());
             net1.fit(in, labels);
             net2.fit(in, labels);
-            assertEquals(net1.params(), net2.params());
+            assertEquals(net1.getModelParams(), net2.getModelParams());
         }
     }
 
@@ -145,10 +145,10 @@ public class CacheModeTest extends BaseDL4JTest {
         INDArray out2 = net2.outputSingle(in);
         assertEquals(out1, out2);
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
         net1.fit(new DataSet(in, labels));
         net2.fit(new DataSet(in, labels));
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
     }
 
     private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){
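
The CacheModeTest hunks above show the second recurring rename in this commit: the flattened parameter vector of a whole network is now read with getModelParams() instead of params(). A sketch of the comparison idiom the tests use, assuming two identically configured and identically seeded networks net1 and net2 (setup elided):

    // Identical configs and seeds should start from, and stay at, identical parameters.
    assertEquals(net1.getModelParams(), net2.getModelParams());  // formerly params()
    net1.fit(in, labels);
    net2.fit(in, labels);
    assertEquals(net1.getModelParams(), net2.getModelParams());
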
@@ -121,7 +121,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
             graph.setInput(0, input);
             graph.setLabel(0, labels);
             graph.computeGradientAndScore();
-            results[i] = graph.score();
+            results[i] = graph.getScore();
         }
 
         assertNotEquals(results[0], results[1]);
@@ -137,7 +137,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
 
         ComputationGraph net = getCNNMnistConfig();
         net.init();
-        net.setListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         for (int i = 0; i < 50; i++) {
             net.fit(mnistTrain.next());
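
The second CenterLossOutputLayerTest hunk also picks up the listener API rename: setListeners(...) becomes addTrainingListeners(...). A sketch, using the ScoreIterationListener shown in the diff; note the additive name suggests repeated calls append listeners rather than replace them, but that behavior is an assumption, not something this diff confirms:

    // Attach a listener that reports the score every iteration.
    net.addTrainingListeners(new ScoreIterationListener(1));  // formerly net.setListeners(...)
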
@@ -265,7 +265,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
         MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
         netSeparate.init();
 
-        assertEquals(netIntegrated.params(), netSeparate.params());
+        assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());
 
         Nd4j.getRandom().setSeed(12345);
         netIntegrated.fit(next);
@@ -273,7 +273,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
         Nd4j.getRandom().setSeed(12345);
         netSeparate.fit(next);
 
-        assertEquals(netIntegrated.params(), netSeparate.params());
+        assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());
 
         // check parameters
         assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
@@ -80,7 +80,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
                         .setFeatureExtractor(1).build();
 
         INDArray paramsLastTwoLayers =
-                        Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
+                        Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
         MultiLayerNetwork notFrozen = new MultiLayerNetwork(
             (NeuralNetConfiguration) overallConf.clone()
                             .layer(0, new Builder().nIn(2).nOut(3).build())
@@ -102,9 +102,9 @@ public class FrozenLayerTest extends BaseDL4JTest {
             modelNow.fit(randomData);
         }
 
-        INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
-                        notFrozen.params());
-        INDArray act = modelNow.params();
+        INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
+                        notFrozen.getModelParams());
+        INDArray act = modelNow.getModelParams();
         assertEquals(expected, act);
     }
 
@@ -136,7 +136,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
         assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson());
 
         //Check params
-        assertEquals(modelNow.params(), clonedModel.params());
+        assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());
 
         MultiLayerNetwork notFrozen = new MultiLayerNetwork(
             (NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build())
@@ -145,7 +145,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
                                                             .activation(Activation.SOFTMAX).nIn(3).nOut(3)
                                                             .build())
                             .build(),
-                        Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
+                        Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()));
 
         int i = 0;
         while (i < 5) {
@@ -155,10 +155,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
             i++;
         }
 
-        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(),
-                        modelToFineTune.getLayer(1).params(), notFrozen.params());
-        assertEquals(expectedParams, modelNow.params());
-        assertEquals(expectedParams, clonedModel.params());
+        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(),
+                        modelToFineTune.getLayer(1).getParams(), notFrozen.getModelParams());
+        assertEquals(expectedParams, modelNow.getModelParams());
+        assertEquals(expectedParams, clonedModel.getModelParams());
 
     }
 
@@ -199,8 +199,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
                         .setOutputs("layer1").build());
 
         notFrozen.init();
-        notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
-                        modelToFineTune.getLayer("layer3").params()));
+        notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
+                        modelToFineTune.getLayer("layer3").getParams()));
 
         int i = 0;
         while (i < 5) {
@@ -209,8 +209,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
             i++;
         }
 
-        assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
-                        modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
+        assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
+                        modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams()), modelNow.getModelParams());
     }
 
     @Test
@@ -244,7 +244,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
         assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());
 
         //Check params
-        assertEquals(modelNow.params(), clonedModel.params());
+        assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());
 
         ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
                         .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
@@ -256,8 +256,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
                                         "layer0")
                         .setOutputs("layer1").build());
         notFrozen.init();
-        notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
-                        modelToFineTune.getLayer("layer3").params()));
+        notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
+                        modelToFineTune.getLayer("layer3").getParams()));
 
 
         int i = 0;
@@ -268,10 +268,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
             i++;
         }
 
-        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
-                        modelToFineTune.getLayer("layer1").params(), notFrozen.params());
-        assertEquals(expectedParams, modelNow.params());
-        assertEquals(expectedParams, clonedModel.params());
+        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
+                        modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams());
+        assertEquals(expectedParams, modelNow.getModelParams());
+        assertEquals(expectedParams, clonedModel.getModelParams());
     }
 
 
@@ -305,7 +305,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
         net2.init();
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
 
 
         String json = conf2.toJson();
@@ -362,7 +362,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
         ComputationGraph net2 = new ComputationGraph(conf2);
         net2.init();
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
 
 
         String json = conf2.toJson();
@@ -75,7 +75,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
         net2.init();
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
 
 
         String json = conf2.toJson();
@@ -130,7 +130,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
         ComputationGraph net2 = new ComputationGraph(conf2);
         net2.init();
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
 
 
         String json = conf2.toJson();
@@ -170,19 +170,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
         MultiLayerNetwork network = new MultiLayerNetwork(conf1);
         network.init();
-        INDArray unfrozenLayerParams = network.getLayer(0).params().dup();
-        INDArray frozenLayerParams1 = network.getLayer(1).params().dup();
-        INDArray frozenLayerParams2 = network.getLayer(2).params().dup();
-        INDArray frozenOutputLayerParams = network.getLayer(3).params().dup();
+        INDArray unfrozenLayerParams = network.getLayer(0).getParams().dup();
+        INDArray frozenLayerParams1 = network.getLayer(1).getParams().dup();
+        INDArray frozenLayerParams2 = network.getLayer(2).getParams().dup();
+        INDArray frozenOutputLayerParams = network.getLayer(3).getParams().dup();
 
         for (int i = 0; i < 100; i++) {
             network.fit(randomData);
         }
 
-        assertNotEquals(unfrozenLayerParams, network.getLayer(0).params());
-        assertEquals(frozenLayerParams1, network.getLayer(1).params());
-        assertEquals(frozenLayerParams2, network.getLayer(2).params());
-        assertEquals(frozenOutputLayerParams, network.getLayer(3).params());
+        assertNotEquals(unfrozenLayerParams, network.getLayer(0).getParams());
+        assertEquals(frozenLayerParams1, network.getLayer(1).getParams());
+        assertEquals(frozenLayerParams2, network.getLayer(2).getParams());
+        assertEquals(frozenOutputLayerParams, network.getLayer(3).getParams());
 
     }
 
@@ -228,19 +228,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
         ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
         computationGraph.init();
-        INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-        INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-        INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-        INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup();
+        INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+        INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+        INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+        INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).getParams().dup();
 
         for (int i = 0; i < 100; i++) {
             computationGraph.fit(randomData);
         }
 
-        assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
-        assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params());
-        assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params());
-        assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params());
+        assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
+        assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
+        assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
+        assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).getParams());
 
     }
 
@@ -275,17 +275,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .build();
         MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
         frozenNetwork.init();
-        INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup();
-        INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup();
-        INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup();
-        INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup();
+        INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).getParams().dup();
+        INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).getParams().dup();
+        INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).getParams().dup();
+        INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).getParams().dup();
 
         MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
         sgdNetwork.init();
-        INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup();
-        INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup();
-        INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup();
-        INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup();
+        INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).getParams().dup();
+        INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).getParams().dup();
+        INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).getParams().dup();
+        INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).getParams().dup();
 
         for (int i = 0; i < 100; i++) {
             frozenNetwork.fit(randomData);
@@ -294,10 +294,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
             sgdNetwork.fit(randomData);
         }
 
-        assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params());
-        assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params());
-        assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params());
-        assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params());
+        assertEquals(frozenNetwork.getLayer(0).getParams(), sgdNetwork.getLayer(0).getParams());
+        assertEquals(frozenNetwork.getLayer(1).getParams(), sgdNetwork.getLayer(1).getParams());
+        assertEquals(frozenNetwork.getLayer(2).getParams(), sgdNetwork.getLayer(2).getParams());
+        assertEquals(frozenNetwork.getLayer(3).getParams(), sgdNetwork.getLayer(3).getParams());
 
     }
 
@@ -360,17 +360,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
         ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
         frozenComputationGraph.init();
-        INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-        INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-        INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-        INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup();
+        INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+        INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+        INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+        INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).getParams().dup();
 
         ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
         sgdComputationGraph.init();
-        INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
-        INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
-        INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
-        INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup();
+        INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
+        INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
+        INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
+        INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).getParams().dup();
 
         for (int i = 0; i < 100; i++) {
             frozenComputationGraph.fit(randomData);
@@ -379,10 +379,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
            sgdComputationGraph.fit(randomData);
         }
 
-        assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
-        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params());
-        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params());
-        assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params());
+        assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
+        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
+        assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
+        assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).getParams(), sgdComputationGraph.getLayer(frozenBranchOutput).getParams());
 
     }
 
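
Across the FrozenLayer tests above, the rename is split by scope: a single layer's parameter vector moves from Layer.params() to Layer.getParams(), while the whole-network flattened vector moves to getModelParams(). A sketch of the distinction, assuming an init()'d MultiLayerNetwork net with at least two layers (names illustrative):

    INDArray layer0Params = net.getLayer(0).getParams();  // one layer's parameters, formerly params()
    INDArray allParams = net.getModelParams();             // full flattened vector, formerly params()
    // As the tests themselves assert, the per-layer vectors hstack into the model-level vector.
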
@@ -68,9 +68,9 @@ public class OutputLayerTest extends BaseDL4JTest {
         INDArray params = Nd4j.create(1, numParams);
         OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf,
                         Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
-        params = l.params();
+        params = l.getModelParams();
         l.setParamsTable(params);
-        assertEquals(params, l.params());
+        assertEquals(params, l.getModelParams());
     }
 
     @Test
@@ -217,8 +217,8 @@ public class OutputLayerTest extends BaseDL4JTest {
             //However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
             //RnnOutputLayer has miniBatch examples
             //Hence: expect difference in scores by factor of timeSeriesLength
-            double score = mln.score() * timeSeriesLength;
-            double scoreRNN = mlnRnn.score();
+            double score = mln.getScore() * timeSeriesLength;
+            double scoreRNN = mlnRnn.getScore();
 
             assertFalse(Double.isNaN(score));
             assertFalse(Double.isNaN(scoreRNN));
@@ -234,7 +234,7 @@ public class OutputLayerTest extends BaseDL4JTest {
 
             RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer();
             //assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
-            //Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version.
+            //Input may be set by BaseLayerConfiguration methods. Thus input may end up as reshaped 2d version instead of original 3d version.
             //Not ideal, but everything else works.
             assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
 
@@ -303,7 +303,7 @@ public class OutputLayerTest extends BaseDL4JTest {
         MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
         mln2.init();
 
-        mln2.setParams(mln.params());
+        mln2.setParams(mln.getModelParams());
 
         INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
 
@@ -330,7 +330,7 @@ public class OutputLayerTest extends BaseDL4JTest {
         mln2.computeGradientAndScore();
 
         assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
-        assertEquals(mln.score(), mln2.score(), 1e-6);
+        assertEquals(mln.getScore(), mln2.getScore(), 1e-6);
 
         TestUtils.testModelSerialization(mln);
     }
@@ -386,7 +386,7 @@ public class OutputLayerTest extends BaseDL4JTest {
                 mln2.init();
 
 
-                mln2.setParams(mln.params());
+                mln2.setParams(mln.getModelParams());
 
 
                 INDArray in = Nd4j.rand(3, 3, 5, 5);
@@ -407,7 +407,7 @@ public class OutputLayerTest extends BaseDL4JTest {
                 mln.computeGradientAndScore();
                 mln2.computeGradientAndScore();
 
-                assertEquals(mln.score(), mln2.score(), 1e-6);
+                assertEquals(mln.getScore(), mln2.getScore(), 1e-6);
                 assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
 
                 //Also check computeScoreForExamples
@@ -479,7 +479,7 @@ public class OutputLayerTest extends BaseDL4JTest {
                 graph2.init();
 
 
-                graph2.setParams(graph.params());
+                graph2.setParams(graph.getModelParams());
 
 
                 INDArray in = Nd4j.rand(3, 3, 5, 5);
@@ -500,7 +500,7 @@ public class OutputLayerTest extends BaseDL4JTest {
                 graph.computeGradientAndScore();
                 graph2.computeGradientAndScore();
 
-                assertEquals(graph.score(), graph2.score(), 1e-6);
+                assertEquals(graph.getScore(), graph2.getScore(), 1e-6);
                 assertEquals(graph.gradient().gradient(), graph2.gradient().gradient());
 
                 //Also check computeScoreForExamples
@@ -59,13 +59,13 @@ public class SeedTest extends BaseDL4JTest {
         layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
 
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        double score = layer.score();
-        INDArray parameters = layer.params();
+        double score = layer.getScore();
+        INDArray parameters = layer.getParams();
         layer.setParams(parameters);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
 
-        double score2 = layer.score();
-        assertEquals(parameters, layer.params());
+        double score2 = layer.getScore();
+        assertEquals(parameters, layer.getParams());
         assertEquals(score, score2, 1e-4);
     }
 }
@@ -845,9 +845,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
 
     public static void testHelper(TestCase tc) {
 
-        tc.net2.params().assign(tc.net1.params());
-        tc.net3.params().assign(tc.net1.params());
-        tc.net4.params().assign(tc.net1.params());
+        tc.net2.getModelParams().assign(tc.net1.getModelParams());
+        tc.net3.getModelParams().assign(tc.net1.getModelParams());
+        tc.net4.getModelParams().assign(tc.net1.getModelParams());
 
         //Test forward pass:
         INDArray inNCHW = tc.inNCHW;
@@ -909,9 +909,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         tc.net3.fit(inNHWC, tc.labelsNHWC);
         tc.net4.fit(inNHWC, tc.labelsNHWC);
 
-        assertEquals(tc.net1.params(), tc.net2.params(), tc.msg);
-        assertEquals(tc.net1.params(), tc.net3.params(), tc.msg);
-        assertEquals(tc.net1.params(), tc.net4.params(), tc.msg);
+        assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
+        assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
+        assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);
 
         //Test serialization
         MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
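
The testHelper hunk above relies on getModelParams() returning a live view of the network's parameter array, so assign(...) copies parameters in place rather than into a detached duplicate; the tests' own usage implies this, mirroring how params() behaved before the rename. A sketch of the idiom, assuming net1 and net2 share an identical layout:

    // Copy net1's parameters into net2 in place via the live parameter view.
    net2.getModelParams().assign(net1.getModelParams());
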
@@ -30,7 +30,6 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
@@ -38,7 +37,6 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.nn.weights.WeightInitNormal;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -450,10 +448,10 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
 
         MultiLayerNetwork net = getCNNMLNConfig(true, false);
 
-        INDArray paramsOrig = net.params().dup();
+        INDArray paramsOrig = net.getModelParams().dup();
         net.setParams(paramsOrig);
 
-        INDArray params2 = net.params();
+        INDArray params2 = net.getModelParams();
 
         assertEquals(paramsOrig, params2);
     }
@@ -154,7 +154,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
         net2.init();
 
-        assertEquals(net2.params(), net.params());
+        assertEquals(net2.getModelParams(), net.getModelParams());
 
         INDArray testFeatures = Nd4j.rand(1, 10);
         INDArray testLabels = Nd4j.zeros(1, 10);
@@ -207,7 +207,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         ComputationGraph net2 = new ComputationGraph(conf2);
         net2.init();
 
-        assertEquals(net2.params(), net.params());
+        assertEquals(net2.getModelParams(), net.getModelParams());
 
         INDArray testFeatures = Nd4j.rand(1, 10);
         INDArray testLabels = Nd4j.zeros(1, 10);
@@ -56,7 +56,7 @@ public class CustomLayer extends FeedForwardLayer {
                                                        boolean initializeParams, DataType networkDataType) {
         LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
         CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType);
-        ret.setListeners(trainingListeners);
+        ret.addTrainingListeners(trainingListeners);
         ret.setIndex(layerIndex);
         ret.setParamsViewArray(layerParamsView);
         Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
@@ -54,7 +54,7 @@ public class CustomOutputLayer extends BaseOutputLayer {
                              int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
         LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
         CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType);
-        ret.setListeners(trainingListeners);
+        ret.addTrainingListeners(trainingListeners);
         ret.setIndex(layerIndex);
         ret.setParamsViewArray(layerParamsView);
         Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
@@ -72,7 +72,7 @@ public class DenseTest extends BaseDL4JTest {
 
         DataSet test = iter.next();
 
-        assertEquals(model.params(), model2.params());
+        assertEquals(model.getModelParams(), model2.getModelParams());
 
         Evaluation eval = new Evaluation();
         INDArray output = model.output(test.getFeatures());
@@ -99,7 +99,7 @@ public class DenseTest extends BaseDL4JTest {
 
         DataSet test = iter.next();
 
-        assertEquals(model.params(), model2.params());
+        assertEquals(model.getModelParams(), model2.getModelParams());
 
         Evaluation eval = new Evaluation();
         INDArray output = model.output(test.getFeatures());
@@ -169,7 +169,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.init();
         net2.init();
 
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());
 
         int batchSize = 3;
         INDArray inEmbedding = Nd4j.create(batchSize, 1);
@@ -216,7 +216,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.init();
         net2.init();
 
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());
 
         int batchSize = 3;
         INDArray inEmbedding = Nd4j.create(batchSize, 1);
@@ -262,7 +262,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.init();
         net2.init();
 
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());
 
         int batchSize = 3;
         INDArray inEmbedding = Nd4j.create(batchSize, 1);
@@ -287,7 +287,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.computeGradientAndScore();
         net2.computeGradientAndScore();
 
-        assertEquals(net2.score(), net.score(), 1e-6);
+        assertEquals(net2.getScore(), net.getScore(), 1e-6);
 
         Map<String, INDArray> gradient = net.gradient().gradientForVariable();
         Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@@ -323,7 +323,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.init();
         net2.init();
 
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());
 
         int batchSize = 3;
         INDArray inEmbedding = Nd4j.create(batchSize, 1);
@@ -349,7 +349,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net2.computeGradientAndScore();
 
 //        System.out.println(net.score() + "\t" + net2.score());
-        assertEquals(net2.score(), net.score(), 1e-6);
+        assertEquals(net2.getScore(), net.getScore(), 1e-6);
 
         Map<String, INDArray> gradient = net.gradient().gradientForVariable();
         Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@@ -395,7 +395,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.init();
         net2.init();
 
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());
 
         INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength);
         INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength);
@@ -422,7 +422,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net2.computeGradientAndScore();
 
 //        System.out.println(net.score() + "\t" + net2.score());
-        assertEquals(net2.score(), net.score(), 1e-5);
+        assertEquals(net2.getScore(), net.getScore(), 1e-5);
 
         Map<String, INDArray> gradient = net.gradient().gradientForVariable();
         Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@@ -484,7 +484,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                 net2.init();
 
-                net2.setParams(net.params().dup());
+                net2.setParams(net.getModelParams().dup());
 
                 INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
                 INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
@@ -523,7 +523,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 net2.computeGradientAndScore();
 
 //                System.out.println(net.score() + "\t" + net2.score());
-                assertEquals(net2.score(), net.score(), 1e-5);
+                assertEquals(net2.getScore(), net.getScore(), 1e-5);
 
                 Map<String, INDArray> gradients = net.gradient().gradientForVariable();
                 Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
@@ -640,7 +640,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                         MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                         net2.init();
 
-                        net2.setParams(net.params().dup());
+                        net2.setParams(net.getModelParams().dup());
 
                         INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
                         INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
@@ -678,7 +678,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                         net.computeGradientAndScore();
                         net2.computeGradientAndScore();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        assertEquals(net2.score(), net.score(), 1e-5);
 | 
					                        assertEquals(net2.getScore(), net.getScore(), 1e-5);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        Map<String, INDArray> gradients = net.gradient().gradientForVariable();
 | 
					                        Map<String, INDArray> gradients = net.gradient().gradientForVariable();
 | 
				
			||||||
                        Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
 | 
					                        Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
 | 
				
			||||||
@ -777,9 +777,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 | 
				
			|||||||
                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
 | 
					                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
 | 
				
			||||||
                net3.init();
 | 
					                net3.init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                INDArray p1 = net.params();
 | 
					                INDArray p1 = net.getModelParams();
 | 
				
			||||||
                INDArray p2 = net2.params();
 | 
					                INDArray p2 = net2.getModelParams();
 | 
				
			||||||
                INDArray p3 = net3.params();
 | 
					                INDArray p3 = net3.getModelParams();
 | 
				
			||||||
                boolean eq = p1.equalsWithEps(p2, 1e-4);
 | 
					                boolean eq = p1.equalsWithEps(p2, 1e-4);
 | 
				
			||||||
                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
 | 
					                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
 | 
				
			||||||
                assertTrue(eq, str + " p1/p2 params not equal");
 | 
					                assertTrue(eq, str + " p1/p2 params not equal");
 | 
				
			||||||
 | 
				
			|||||||
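Note: every EmbeddingLayerTest hunk above applies the same two accessor renames, params() → getModelParams() and score() → getScore(). A minimal sketch of the new call pattern, assuming two identically configured networks with input and labels already set (class and helper names are illustrative, not part of this commit):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    public class ParamSyncSketch {
        // net and net2 must share an architecture and already have input/labels set.
        static void syncAndCompare(MultiLayerNetwork net, MultiLayerNetwork net2) {
            // dup() detaches the copy from net's flattened parameter buffer
            net2.setParams(net.getModelParams().dup());                 // was: net.params().dup()
            net.computeGradientAndScore();
            net2.computeGradientAndScore();
            double diff = Math.abs(net.getScore() - net2.getScore());   // was: score()
            if (diff > 1e-6) throw new AssertionError("scores diverged by " + diff);
        }
    }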
@@ -514,7 +514,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(new ScoreIterationListener(100));
+        net.addTrainingListeners(new ScoreIterationListener(100));

         int nEpochs = 1000;
         DataSet ds = iter.next();
@@ -79,13 +79,13 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
         if (doLearningFirst) {
             //Run a number of iterations of learning
             network.setInput(arr);
-            network.setListeners(new ScoreIterationListener(1));
+            network.addTrainingListeners(new ScoreIterationListener(1));
             network.computeGradientAndScore();
-            double scoreBefore = network.score();
+            double scoreBefore = network.getScore();
             for (int j = 0; j < 10; j++)
                 network.fit(ds);
             network.computeGradientAndScore();
-            double scoreAfter = network.score();
+            double scoreAfter = network.getScore();
             //Can't test in 'characteristic mode of operation' if not learning
             String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn="
                     + "relu" + ", lossFn=" + "ocnn" + ", "  + "sigmoid"
@@ -147,7 +147,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
         tmpFile.deleteOnExit();

         MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile);
-        assertEquals(network.params(),multiLayerNetwork.params());
+        assertEquals(network.getModelParams(),multiLayerNetwork.getModelParams());
         assertEquals(network.numParams(),multiLayerNetwork.numParams());

     }
@@ -187,7 +187,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
                 .build();
         MultiLayerNetwork network = new MultiLayerNetwork(configuration);
         network.init();
-        network.setListeners(new ScoreIterationListener(1));
+        network.addTrainingListeners(new ScoreIterationListener(1));
         return network;
     }

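Note: the listener hunks here and in the files below replace setListeners(...) with addTrainingListeners(...), which (as the name suggests) appends rather than replaces. A minimal sketch of the renamed call site (helper name is illustrative, not from this commit):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

    public class ListenerSketch {
        // Logs the training score every `frequency` iterations during fit().
        static void attachScoreLogging(MultiLayerNetwork network, int frequency) {
            // was: network.setListeners(new ScoreIterationListener(frequency));
            network.addTrainingListeners(new ScoreIterationListener(frequency));
        }
    }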
@@ -124,7 +124,7 @@ public class BidirectionalTest extends BaseDL4JTest {
                 assertEquals(n1, n2);
             }

-            net2.setParams(net1.params());  //Assuming exact same layout here...
+            net2.setParams(net1.getModelParams());  //Assuming exact same layout here...

             INDArray in;
             if (rnnDataFormat == NCW){
@@ -154,7 +154,7 @@ public class BidirectionalTest extends BaseDL4JTest {
             net2.computeGradientAndScore();

             //Ensure scores are equal:
-            assertEquals(net1.score(), net2.score(), 1e-6);
+            assertEquals(net1.getScore(), net2.getScore(), 1e-6);

             //Ensure gradients are equal:
             Gradient g1 = net1.gradient();
@@ -174,8 +174,8 @@ public class BidirectionalTest extends BaseDL4JTest {
             net1.fit(in, labels);
             net2.fit(in, labels);

-            INDArray p1 = net1.params();
-            INDArray p2 = net2.params();
+            INDArray p1 = net1.getModelParams();
+            INDArray p2 = net2.getModelParams();
             assertEquals(p1, p2);
         }
     }
@@ -232,7 +232,7 @@ public class BidirectionalTest extends BaseDL4JTest {
                 assertEquals(n1, n2);
             }

-            net2.setParams(net1.params());  //Assuming exact same layout here...
+            net2.setParams(net1.getModelParams());  //Assuming exact same layout here...

             INDArray in = Nd4j.rand(3, 10, 5);

@@ -253,7 +253,7 @@ public class BidirectionalTest extends BaseDL4JTest {
             net2.computeGradientAndScore();

             //Ensure scores are equal:
-            assertEquals(net1.score(), net2.score(), 1e-6);
+            assertEquals(net1.getScore(), net2.getScore(), 1e-6);

             //Ensure gradients are equal:
             Gradient g1 = net1.gradient();
@@ -273,8 +273,8 @@ public class BidirectionalTest extends BaseDL4JTest {
             net1.fit(new DataSet(in, labels));
             net2.fit(new DataSet(in, labels));

-            INDArray p1 = net1.params();
-            INDArray p2 = net2.params();
+            INDArray p1 = net1.getModelParams();
+            INDArray p2 = net2.getModelParams();
             assertEquals(p1, p2);
         }
     }
@@ -340,7 +340,7 @@ public class BidirectionalTest extends BaseDL4JTest {
             net1.computeGradientAndScore();
             net2.computeGradientAndScore();

-            assertEquals(net1.score(), net2.score(), 1e-6);
+            assertEquals(net1.getScore(), net2.getScore(), 1e-6);
             assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
         }
     }
@@ -403,7 +403,7 @@ public class BidirectionalTest extends BaseDL4JTest {
             net1.computeGradientAndScore();
             net2.computeGradientAndScore();

-            assertEquals(net1.score(), net2.score(), 1e-6);
+            assertEquals(net1.getScore(), net2.getScore(), 1e-6);
             assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
         }
     }
@@ -277,7 +277,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {

         final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());

-        params = bidirectionalLSTM.params();
+        params = bidirectionalLSTM.getModelParams();

         bidirectionalLSTM.setParamsTable(params);

@@ -285,9 +285,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {

         public static void testHelper(TestCase tc) {

-            tc.net2.params().assign(tc.net1.params());
-            tc.net3.params().assign(tc.net1.params());
-            tc.net4.params().assign(tc.net1.params());
+            tc.net2.getModelParams().assign(tc.net1.getModelParams());
+            tc.net3.getModelParams().assign(tc.net1.getModelParams());
+            tc.net4.getModelParams().assign(tc.net1.getModelParams());

             INDArray inNCW = tc.inNCW;
             INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
@@ -352,9 +352,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {
             tc.net3.fit(inNWC, tc.labelsNWC);
             tc.net4.fit(inNWC, tc.labelsNWC);

-            assertEquals(tc.net1.params(), tc.net2.params(), tc.msg);
-            assertEquals(tc.net1.params(), tc.net3.params(), tc.msg);
-            assertEquals(tc.net1.params(), tc.net4.params(), tc.msg);
+            assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
+            assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
+            assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);

             //Test serialization
             MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
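Note: RnnDataFormatTests syncs weights across networks by assigning into the flattened parameter view rather than calling setParams(). A minimal sketch of that pattern under the renamed accessor (class and helper names are illustrative, not from this commit):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    public class WeightShareSketch {
        // Overwrites target's weights in place: getModelParams() returns the
        // flattened parameter vector as a view, so assign() copies into it.
        static void copyWeights(MultiLayerNetwork source, MultiLayerNetwork target) {
            target.getModelParams().assign(source.getModelParams());  // was: params().assign(params())
        }
    }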
@@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.dropout.TestDropout;
 import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@@ -173,8 +172,8 @@ public class TestRnnLayers extends BaseDL4JTest {
             MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2);
             netD2.init();

-            assertEquals(net.params(), netD.params(), s);
-            assertEquals(net.params(), netD2.params(), s);
+            assertEquals(net.getModelParams(), netD.getModelParams(), s);
+            assertEquals(net.getModelParams(), netD2.getModelParams(), s);

             INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10);

@@ -193,7 +192,7 @@ public class TestRnnLayers extends BaseDL4JTest {
             INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
             net.fit(f.dup(), l);
             netD.fit(f.dup(), l);
-            assertNotEquals(net.params(), netD.params(), s);
+            assertNotEquals(net.getModelParams(), netD.getModelParams(), s);

             netD2.fit(f.dup(), l);
             netD2.fit(f.dup(), l);
@@ -115,7 +115,7 @@ public class TestTimeDistributed extends BaseDL4JTest {
                     net1.fit(ds);
                     net2.fit(ds);

-                    assertEquals(net1.params(), net2.params());
+                    assertEquals(net1.getModelParams(), net2.getModelParams());

                     MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
                     out2 = net2.output(in);
@@ -124,10 +124,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
                     MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                     net2.init();

-                    net.params().assign(net2.params());
+                    net.getModelParams().assign(net2.getModelParams());

                     //Check params:
-                    assertEquals(net2.params(), net.params());
+                    assertEquals(net2.getModelParams(), net.getModelParams());
                     Map<String, INDArray> params1 = net.getParamTable();
                     Map<String, INDArray> params2 = net2.getParamTable();
                     assertEquals(params2, params1);
@@ -209,10 +209,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
                     MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                     net2.init();

-                    assertEquals(net2.params(), net.params());
+                    assertEquals(net2.getModelParams(), net.getModelParams());

                     //Check params:
-                    assertEquals(net2.params(), net.params());
+                    assertEquals(net2.getModelParams(), net.getModelParams());
                     Map<String, INDArray> params1 = net.getParamTable();
                     Map<String, INDArray> params2 = net2.getParamTable();
                     assertEquals(params2, params1);
@@ -287,10 +287,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
                     MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
                     netStandard.init();

-                    netSD.params().assign(netStandard.params());
+                    netSD.getModelParams().assign(netStandard.getModelParams());

                     //Check params:
-                    assertEquals(netStandard.params(), netSD.params());
+                    assertEquals(netStandard.getModelParams(), netSD.getModelParams());
                     assertEquals(netStandard.getParamTable(), netSD.getParamTable());

                     INDArray in = Nd4j.rand(minibatch, nIn);
@@ -379,10 +379,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
             MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
             netStandard.init();

-            netSD.params().assign(netStandard.params());
+            netSD.getModelParams().assign(netStandard.getModelParams());

             //Check params:
-            assertEquals(netStandard.params(), netSD.params());
+            assertEquals(netStandard.getModelParams(), netSD.getModelParams());
             assertEquals(netStandard.getParamTable(), netSD.getParamTable());

             DataSetIterator iter = new IrisDataSetIterator(150, 150);
@@ -398,7 +398,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
                 netStandard.fit(ds);
                 String s = String.valueOf(i);
                 assertEquals( netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s);
-                assertEquals( netStandard.params(), netSD.params(), s);
+                assertEquals( netStandard.getModelParams(), netSD.getModelParams(), s);
                 assertEquals( netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s);
             }

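Note: the SameDiff test hunks check parameter equality two ways, via the flattened vector and via the per-variable parameter table. A minimal sketch combining both checks (the class name and the key examples in the comment are assumptions, not from this commit):

    import java.util.Map;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class ParamEqualitySketch {
        static boolean sameParameters(MultiLayerNetwork a, MultiLayerNetwork b) {
            // One flattened vector for the whole model...
            boolean flatEqual = a.getModelParams().equals(b.getModelParams());
            // ...and a per-variable map (keys such as "0_W", "0_b")
            Map<String, INDArray> ta = a.getParamTable();
            Map<String, INDArray> tb = b.getParamTable();
            return flatEqual && ta.equals(tb);
        }
    }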
@@ -100,10 +100,10 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
                    ComputationGraph netStandard = new ComputationGraph(conf2);
                    netStandard.init();

-                    netSD.params().assign(netStandard.params());
+                    netSD.getModelParams().assign(netStandard.getModelParams());

                    //Check params:
-                    assertEquals(netStandard.params(), netSD.params());
+                    assertEquals(netStandard.getModelParams(), netSD.getModelParams());
                    assertEquals(netStandard.getParamTable(), netSD.getParamTable());

                    INDArray in = Nd4j.rand(minibatch, nIn);
@@ -160,7 +160,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
                        netStandard.fit(ds);

                        assertEquals(netStandard.getParamTable(), netSD.getParamTable());
-                        assertEquals(netStandard.params(), netSD.params());
+                        assertEquals(netStandard.getModelParams(), netSD.getModelParams());
                        assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients());
                    }

@@ -98,7 +98,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
             ComputationGraph std = new ComputationGraph(confStd);
             std.init();

-            lambda.setParams(std.params());
+            lambda.setParams(std.getModelParams());

             INDArray in = Nd4j.rand(3, 5);
             INDArray labels = TestUtils.randomOneHot(3, 5);
@@ -119,7 +119,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
                 std.fit(ds);

                 String s = String.valueOf(i);
-                assertEquals(std.params(), lambda.params(), s);
+                assertEquals(std.getModelParams(), lambda.getModelParams(), s);
                 assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
             }

@@ -182,7 +182,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
             ComputationGraph std = new ComputationGraph(confStd);
             std.init();

-            lambda.setParams(std.params());
+            lambda.setParams(std.getModelParams());

             INDArray in1 = Nd4j.rand(3, 5);
             INDArray in2 = Nd4j.rand(3, 5);
@@ -204,7 +204,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
                 std.fit(mds);

                 String s = String.valueOf(i);
-                assertEquals(std.params(), lambda.params(), s);
+                assertEquals(std.getModelParams(), lambda.getModelParams(), s);
                 assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
             }

@@ -85,7 +85,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
             netSD.fit(ds);
             netStd.fit(ds);

-            assertEquals(netStd.params(), netSD.params());
+            assertEquals(netStd.getModelParams(), netSD.getModelParams());
             assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients());
         }

@@ -131,7 +131,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
             MultiLayerNetwork netStd = new MultiLayerNetwork(confStd);
             netStd.init();

-            netSD.params().assign(netStd.params());
+            netSD.getModelParams().assign(netStd.getModelParams());

             assertEquals(netStd.getParamTable(), netSD.getParamTable());

@@ -165,7 +165,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
                 netSD.fit(ds);
                 netStd.fit(ds);
                 String s = String.valueOf(i);
-                assertEquals( netStd.params(), netSD.params(), s);
+                assertEquals( netStd.getModelParams(), netSD.getModelParams(), s);
                 assertEquals( netStd.getFlattenedGradients(), netSD.getFlattenedGradients(),s );
             }

@@ -77,7 +77,7 @@ public class TestVAE extends BaseDL4JTest {
         net.init();

         System.out.println("Exp num params: " + expNumParams);
-        assertEquals(expNumParams, net.getLayer(0).params().length());
+        assertEquals(expNumParams, net.getLayer(0).getParams().length());
         Map<String, INDArray> paramTable = net.getLayer(0).getParamTable();
         int count = 0;
         for (INDArray arr : paramTable.values()) {
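Note: the TestVAE hunk renames the per-layer accessor to getParams(), distinct from the network-level getModelParams(). A minimal sketch of counting parameters layer by layer (the helper name is illustrative; getnLayers()/getLayer() are assumed to be the usual DL4J accessors):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    public class LayerParamCountSketch {
        static long totalLayerParams(MultiLayerNetwork net) {
            long total = 0;
            for (int i = 0; i < net.getnLayers(); i++) {
                total += net.getLayer(i).getParams().length();  // was: getLayer(i).params()
            }
            return total;  // expected to equal net.getModelParams().length()
        }
    }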
@@ -79,7 +79,7 @@ public class CloseNetworkTests extends BaseDL4JTest {

                 net.close();

-                assertTrue(net.params().wasClosed());
+                assertTrue(net.getModelParams().wasClosed());
                 if(train) {
                     assertTrue(net.getGradientsViewArray().wasClosed());
                     Updater u = net.getUpdater(false);
@@ -127,7 +127,7 @@ public class CloseNetworkTests extends BaseDL4JTest {

                 net.close();

-                assertTrue(net.params().wasClosed());
+                assertTrue(net.getModelParams().wasClosed());
                 if(train) {
                     assertTrue(net.getGradientsViewArray().wasClosed());
                     Updater u = net.getUpdater(false);
@@ -57,7 +57,7 @@ public class LargeNetTest extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        INDArray params = net.params();
+        INDArray params = net.getModelParams();
         long paramsLength = params.length();
         long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
         assertEquals(expParamsLength, paramsLength);
@@ -91,7 +91,7 @@ public class LargeNetTest extends BaseDL4JTest {
         ComputationGraph net = new ComputationGraph(conf);
         net.init();

-        INDArray params = net.params();
+        INDArray params = net.getModelParams();
         long paramsLength = params.length();
         long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
         assertEquals(expParamsLength, paramsLength);
@@ -76,7 +76,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.init();
         net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf2.setIterationCount(conf.getIterationCount());
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());

         assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0);
         net.setLearningRate(0, 0.5);  //Set LR for layer 0 to 0.5
@@ -96,7 +96,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net2.fit(in, l);
         }

-        assertEquals(net.params(), net2.params());
+        assertEquals(net.getModelParams(), net2.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());

         INDArray in1 = Nd4j.rand(10, 10);
@@ -110,7 +110,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.setLabels(l1);
         net2.computeGradientAndScore();

-        assertEquals(net.score(), net2.score(), 1e-8);
+        assertEquals(net.getScore(), net2.getScore(), 1e-8);


         //Now: Set *all* LRs to say 0.3...
@@ -126,7 +126,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net3.init();
         net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf3.setIterationCount(conf.getIterationCount());
-        net3.setParams(net.params().dup());
+        net3.setParams(net.getModelParams().dup());

         net.setLearningRate(0.3);

@@ -139,7 +139,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net3.fit(in, l);
         }

-        assertEquals(net.params(), net3.params());
+        assertEquals(net.getModelParams(), net3.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
     }

@@ -206,7 +206,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.init();
         net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf2.setIterationCount(conf.getIterationCount());
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());

         net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 ));  //Set LR for layer 0 to 0.5

@@ -224,7 +224,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net2.fit(in, l);
         }

-        assertEquals(net.params(), net2.params());
+        assertEquals(net.getModelParams(), net2.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
     }

@@ -270,7 +270,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.init();
         net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf2.setIterationCount(conf.getIterationCount());
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());

         assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0);
         net.setLearningRate("0", 0.5);  //Set LR for layer 0 to 0.5
@@ -290,7 +290,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net2.fit(new DataSet(in, l));
         }

-        assertEquals(net.params(), net2.params());
+        assertEquals(net.getModelParams(), net2.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());

         INDArray in1 = Nd4j.rand(10, 10);
@@ -304,7 +304,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.setLabels(l1);
         net2.computeGradientAndScore();

-        assertEquals(net.score(), net2.score(), 1e-8);
+        assertEquals(net.getScore(), net2.getScore(), 1e-8);


         //Now: Set *all* LRs to say 0.3...
@@ -320,7 +320,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net3.init();
         net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf3.setIterationCount(conf.getIterationCount());
-        net3.setParams(net.params().dup());
+        net3.setParams(net.getModelParams().dup());

         net.setLearningRate(0.3);

@@ -333,7 +333,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net3.fit(new DataSet(in, l));
         }

-        assertEquals(net.params(), net3.params());
+        assertEquals(net.getModelParams(), net3.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
     }

@@ -375,7 +375,7 @@ public class TestLrChanges extends BaseDL4JTest {
         net2.init();
         net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
         conf2.setIterationCount(conf.getIterationCount());
-        net2.setParams(net.params().dup());
+        net2.setParams(net.getModelParams().dup());

         net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 ));  //Set LR for layer 0 to 0.5

@@ -393,7 +393,7 @@ public class TestLrChanges extends BaseDL4JTest {
             net2.fit(new DataSet(in, l));
         }

-        assertEquals(net.params(), net2.params());
+        assertEquals(net.getModelParams(), net2.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
     }

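Note: the TestLrChanges hunks repeatedly clone a partially trained network, copying updater state, iteration count, and a duplicated parameter vector so both nets continue identically. A minimal sketch of the state-mirroring steps under the renamed accessor (class and helper names are illustrative, not from this commit):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    public class TrainingStateMirrorSketch {
        // Copies optimizer state and weights so net2 continues training exactly
        // where net left off; the tests also sync the configuration's iteration
        // count via conf2.setIterationCount(conf.getIterationCount()).
        static void mirrorTrainingState(MultiLayerNetwork net, MultiLayerNetwork net2) {
            net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
            net2.setParams(net.getModelParams().dup());  // dup(): a detached copy, not a view
        }
    }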
@@ -77,14 +77,14 @@ public class TestNetConversion extends BaseDL4JTest {
             n.computeGradientAndScore();
             cg.computeGradientAndScore();

-            assertEquals(n.score(), cg.score(), 1e-6);
+            assertEquals(n.getScore(), cg.getScore(), 1e-6);

             assertEquals(n.gradient().gradient(), cg.gradient().gradient());

             n.fit(in, labels);
             cg.fit(new INDArray[]{in}, new INDArray[]{labels});

-            assertEquals(n.params(), cg.params());
+            assertEquals(n.getModelParams(), cg.getModelParams());
         }
     }

@@ -476,7 +476,7 @@ public class WorkspaceTests extends BaseDL4JTest {

             final ComputationGraph computationGraph = new ComputationGraph(config);
             computationGraph.init();
-            computationGraph.setListeners(new ScoreIterationListener(3));
+            computationGraph.addTrainingListeners(new ScoreIterationListener(3));

             WSTestDataSetIterator iterator = new WSTestDataSetIterator();
             computationGraph.fit(iterator);
@@ -54,7 +54,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
     public void testMLPTrivial() {
         //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1.
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID));
-        network.setListeners(new ScoreIterationListener(1));
+        network.addTrainingListeners(new ScoreIterationListener(1));
         network.init();

         DataSetIterator iter = new IrisDataSetIterator(1, 10);
@ -64,7 +64,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 | 
				
			|||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
					import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
 | 
					import org.deeplearning4j.nn.conf.layers.AutoEncoder;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
 | 
					import org.deeplearning4j.nn.conf.layers.BatchNormalization;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 | 
				
			||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
 | 
					import org.deeplearning4j.nn.conf.layers.DenseLayer;
 | 
				
			||||||
@@ -184,13 +184,13 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork network3 = new MultiLayerNetwork(conf);
     network3.init();
 
-    INDArray params = network3.params();
+    INDArray params = network3.getModelParams();
     INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup();
     INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup();
     network3.setParameters(params);
     assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY));
     assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY));
-    INDArray params4 = network3.params();
+    INDArray params4 = network3.getModelParams();
     assertEquals(params, params4);
   }
 
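Note: this is the first of many params() → getModelParams() renames. Assuming the renamed accessor still returns the network's flattened parameter vector, the round-trip invariant the test exercises can be sketched as follows (roundTripIsStable is an illustrative name):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class ParamsRoundTripSketch {
        // Writing a network's own flattened parameters back should change nothing.
        static boolean roundTripIsStable(MultiLayerNetwork net) {
            INDArray before = net.getModelParams().dup(); // dup() detaches from the live view
            net.setParameters(before);
            return before.equals(net.getModelParams());
        }
    }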
@@ -211,7 +211,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 
     MultiLayerNetwork network = new MultiLayerNetwork(conf);
     network.init();
-    network.setListeners(new ScoreIterationListener(1));
+    network.addTrainingListeners(new ScoreIterationListener(1));
 
     DataSetIterator iter = new IrisDataSetIterator(150, 150);
 
@@ -242,7 +242,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 
     MultiLayerNetwork network = new MultiLayerNetwork(conf);
     network.init();
-    network.setListeners(new ScoreIterationListener(1));
+    network.addTrainingListeners(new ScoreIterationListener(1));
 
     DataSetIterator iter = new IrisDataSetIterator(150, 150);
 
@@ -330,7 +330,7 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork model = new MultiLayerNetwork(conf);
     model.init();
 
-    model.addListeners(new ScoreIterationListener(listenerFreq));
+    model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
 
     log.info("Train model....");
     int cnt = 0;
@@ -503,7 +503,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 
     assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName());
     assertEquals(layerNameList, net.getLayerNames());
-    BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).getLayerConfiguration();
+    BaseLayerConfiguration b = (BaseLayerConfiguration) net.getLayer(layerNameList.get(2)).getLayerConfiguration();
     assertEquals("softmax", b.getActivationFn().toString());
   }
 
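Note: the configuration class BaseLayer is renamed to BaseLayerConfiguration, so every downcast of getLayerConfiguration() follows suit. A sketch of the new cast (the helper method is illustrative, and it assumes the target layer's configuration actually extends BaseLayerConfiguration):

    import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    class LayerConfCastSketch {
        // Read the activation function name off one layer's configuration.
        static String activationName(MultiLayerNetwork net, int layerIndex) {
            BaseLayerConfiguration conf =
                    (BaseLayerConfiguration) net.getLayer(layerIndex).getLayerConfiguration();
            return conf.getActivationFn().toString();
        }
    }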
@@ -535,7 +535,7 @@ public class MultiLayerTest extends BaseDL4JTest {
 
     MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg);
     netNoReg.init();
-    netNoReg.setParameters(net.params().dup());
+    netNoReg.setParameters(net.getModelParams().dup());
 
     //Score single example, and compare to scoreExamples:
     INDArray input = Nd4j.rand(3, nIn);
@@ -703,7 +703,7 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork net = new MultiLayerNetwork(conf);
     net.init();
     net.fit(iter.next());
-    // TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars
+    // TODO validate actual layer gradientView - issue getting var out of BaseLayerConfiguration w/o adding MLN getter that gets confused with local gradient vars
     Gradient actualGradient = net.gradient;
     assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W"));
 
@@ -716,13 +716,13 @@ public class MultiLayerTest extends BaseDL4JTest {
     net.setParam("0_b", Nd4j.ones(1, 5));
     net.setParam("1_W", Nd4j.ones(5, 3));
     net.setParam("1_b", Nd4j.ones(1, 3));
-    INDArray actualParams = net.params();
+    INDArray actualParams = net.getModelParams();
 
     // Confirm params
     assertEquals(expectedGradient.gradient(), actualParams);
 
     net.update(expectedGradient);
-    actualParams = net.params();
+    actualParams = net.getModelParams();
     assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
   }
 
@@ -762,7 +762,7 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork aePre = getAeModel(true, nIn, nOut);
     int actualNP = (int) aePre.numParams();
     assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
-    INDArray params = aePre.params();
+    INDArray params = aePre.getModelParams();
     assertEquals(params.length(), actualNP); // check num params
     Map<String, INDArray> paramTable = aePre.getParamTable();
     assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer
@@ -774,7 +774,7 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut);
     actualNP = (int) aeNoPre.numParams();
     assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
-    params = aeNoPre.params();
+    params = aeNoPre.getModelParams();
     assertEquals(params.length(), actualNP);
     paramTable = aePre.getParamTable();
     assertTrue(paramTable.containsKey("0_vb"));
@@ -865,14 +865,14 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
     net2.init();
 
-    BaseLayer bl0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration();
+    BaseLayerConfiguration bl0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
     assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6);
     assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6);
 
     INDArray features = Nd4j.rand(10, 10);
     INDArray labels = Nd4j.rand(10, 10);
 
-    net2.setParams(net1.params().dup());
+    net2.setParams(net1.getModelParams().dup());
 
     net1.setInput(features);
     net1.setLabels(labels);
@@ -888,15 +888,15 @@ public class MultiLayerTest extends BaseDL4JTest {
     r = net2.calcRegularizationScore(true);
     assertEquals(0.0, r, 0.0);
 
-    double s1 = net1.score();
-    double s2 = net2.score();
+    double s1 = net1.getScore();
+    double s2 = net2.getScore();
     assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score
 
     for (int i = 0; i < 10; i++) {
       net1.fit(features, labels);
     }
 
-    net2.setParams(net1.params().dup());
+    net2.setParams(net1.getModelParams().dup());
     net1.computeGradientAndScore();
     net2.computeGradientAndScore();
 
@@ -906,8 +906,8 @@ public class MultiLayerTest extends BaseDL4JTest {
     r = net2.calcRegularizationScore(true);
     assertTrue(r > 0.0);
 
-    s1 = net1.score();
-    s2 = net2.score();
+    s1 = net1.getScore();
+    s2 = net2.getScore();
 
     assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2
 
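Note: score() becomes getScore(). As the hunks above show, the value is only populated once computeGradientAndScore() (or a fit) has run. The usual pattern, sketched under the assumption that the renamed getter keeps the old semantics:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class ScoreSketch {
        // Compute and read the loss for one (features, labels) pair.
        static double lossFor(MultiLayerNetwork net, INDArray features, INDArray labels) {
            net.setInput(features);
            net.setLabels(labels);
            net.computeGradientAndScore(); // populates the score read below
            return net.getScore();
        }
    }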
@@ -1022,11 +1022,11 @@ public class MultiLayerTest extends BaseDL4JTest {
     MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
     net2.init();
 
-    assertNotEquals(net1.params(), net2.params());
+    assertNotEquals(net1.getModelParams(), net2.getModelParams());
     assertNotEquals(net1.getParamTable(), net2.getParamTable());
 
     net1.setParamTable(net2.getParamTable());
-    assertEquals(net1.params(), net2.params());
+    assertEquals(net1.getModelParams(), net2.getModelParams());
     assertEquals(net1.getParamTable(), net2.getParamTable());
   }
 
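Note: the map-based accessors getParamTable()/setParamTable(...) are untouched by this commit; only the flattened-view getter is renamed. Copying parameters between two identically configured networks can then be sketched like this (an illustrative helper; the final check assumes getModelParams() reflects the table write):

    import java.util.Map;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class ParamTableCopySketch {
        // Copy all parameters from src to dst by name, then confirm the
        // flattened views agree. Both nets must share one architecture.
        static void copyParams(MultiLayerNetwork src, MultiLayerNetwork dst) {
            Map<String, INDArray> table = src.getParamTable();
            dst.setParamTable(table);
            if (!dst.getModelParams().equals(src.getModelParams())) {
                throw new IllegalStateException("parameter copy did not take effect");
            }
        }
    }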
@@ -1412,7 +1412,7 @@ public class MultiLayerTest extends BaseDL4JTest {
     exp.add(MultiLayerNetwork.class);
 
     CheckModelsListener listener = new CheckModelsListener();
-    net.setListeners(listener);
+    net.addTrainingListeners(listener);
 
     INDArray f = Nd4j.create(1, 10);
     INDArray l = Nd4j.create(1, 10);
@@ -753,9 +753,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 
         DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);
 
-        INDArray initialParams = mln.params().dup();
+        INDArray initialParams = mln.getModelParams().dup();
         mln.fit(ds);
-        INDArray afterParams = mln.params();
+        INDArray afterParams = mln.getModelParams();
         assertNotEquals(initialParams, afterParams);
     }
 
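Note: the RNN hunk keeps a simple sanity check that survives the rename: training must actually move the parameters. A sketch, assuming getModelParams() returns the live flattened view (hence the dup() before fitting):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;

    class FitMovesParamsSketch {
        // Snapshot the params, fit once, and report whether anything changed.
        static boolean fitChangedParams(MultiLayerNetwork net, DataSet ds) {
            INDArray before = net.getModelParams().dup(); // snapshot, not the live view
            net.fit(ds);
            return !before.equals(net.getModelParams());
        }
    }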
@@ -172,7 +172,7 @@ public class TestMasking extends BaseDL4JTest {
                 net.setLabels(labels);
 
                 net.computeGradientAndScore();
-                double score1 = net.score();
+                double score1 = net.getScore();
                 INDArray grad1 = net.gradient().gradient();
 
                 //Now: change the label values for the masked steps. The
@@ -187,7 +187,7 @@ public class TestMasking extends BaseDL4JTest {
 
                 assertNotEquals(labels, newLabels);
 
-                double score2 = net.score();
+                double score2 = net.getScore();
                 INDArray grad2 = net.gradient().gradient();
 
                 assertEquals(score1, score2, 1e-6);
@@ -214,7 +214,7 @@ public class TestMasking extends BaseDL4JTest {
                 graph.setLabels(labels);
                 graph.computeGradientAndScore();
 
-                double gScore1 = graph.score();
+                double gScore1 = graph.getScore();
                 INDArray gGrad1 = graph.gradient().gradient();
 
                 graph.setLayerMaskArrays(null, new INDArray[] {labelMask});
@@ -222,7 +222,7 @@ public class TestMasking extends BaseDL4JTest {
                 graph.setLabels(newLabels);
                 graph.computeGradientAndScore();
 
-                double gScore2 = graph.score();
+                double gScore2 = graph.getScore();
                 INDArray gGrad2 = graph.gradient().gradient();
 
                 assertEquals(gScore1, gScore2, 1e-6);
@@ -53,12 +53,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
 
-        INDArray initParams = net.params().dup();
+        INDArray initParams = net.getModelParams().dup();
         Map<String, INDArray> initParams2 = net.getParamTable();
 
-        net.setParams(net.params());
+        net.setParams(net.getModelParams());
 
-        INDArray initParamsAfter = net.params();
+        INDArray initParamsAfter = net.getModelParams();
         Map<String, INDArray> initParams2After = net.getParamTable();
 
         for (String s : initParams2.keySet()) {
@@ -71,7 +71,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
         INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
         net.setParams(randomParams.dup());
 
-        assertEquals(net.params(), randomParams);
+        assertEquals(net.getModelParams(), randomParams);
     }
 
     @Test
@@ -90,12 +90,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
 
-        INDArray initParams = net.params().dup();
+        INDArray initParams = net.getModelParams().dup();
         Map<String, INDArray> initParams2 = net.getParamTable();
 
-        net.setParams(net.params());
+        net.setParams(net.getModelParams());
 
-        INDArray initParamsAfter = net.params();
+        INDArray initParamsAfter = net.getModelParams();
         Map<String, INDArray> initParams2After = net.getParamTable();
 
         for (String s : initParams2.keySet()) {
@@ -108,7 +108,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
         INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
         net.setParams(randomParams.dup());
 
-        assertEquals(net.params(), randomParams);
+        assertEquals(net.getModelParams(), randomParams);
     }
 
     @Test
@@ -128,7 +128,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
 
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        INDArray params = net.params();
+        INDArray params = net.getModelParams();
 
 
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
@@ -137,11 +137,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
         MultiLayerNetwork net3 = new MultiLayerNetwork(conf);
         net3.init(params, false);
 
-        assertEquals(params, net2.params());
-        assertEquals(params, net3.params());
+        assertEquals(params, net2.getModelParams());
+        assertEquals(params, net3.getModelParams());
 
-        assertNotSame(params, net2.params()); //Different objects due to clone
-        assertSame(params, net3.params()); //Same object due to clone
+        assertNotSame(params, net2.getModelParams()); //Different objects due to clone
+        assertSame(params, net3.getModelParams()); //Same object due to clone
 
 
         Map<String, INDArray> paramsMap = net.getParamTable();
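Note: the last hunk pins down the two-argument init(params, cloneParametersArray) overload: with false the network adopts the caller's array itself; with true it works on a copy. Sketched against the renamed getter (both networks are assumed freshly constructed from the same configuration; run with -ea for the asserts):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class InitAliasingSketch {
        static void checkAliasing(MultiLayerNetwork netA, MultiLayerNetwork netB, INDArray params) {
            netA.init(params, false);                    // adopt the array: no clone
            assert netA.getModelParams() == params;      // same object

            netB.init(params, true);                     // defensive copy
            assert netB.getModelParams() != params;      // different object...
            assert netB.getModelParams().equals(params); // ...with equal values
        }
    }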
@@ -103,14 +103,14 @@ public class TestVariableLengthTS extends BaseDL4JTest {
             net.setInput(in1);
             net.setLabels(labels1);
             net.computeGradientAndScore();
-            double score1 = net.score();
+            double score1 = net.getScore();
             Gradient g1 = net.gradient();
 
             net.setInput(in2);
             net.setLabels(labels2);
             net.setLayerMaskArrays(null, labelMask);
             net.computeGradientAndScore();
-            double score2 = net.score();
+            double score2 = net.getScore();
             Gradient g2 = net.gradient();
 
             //Scores and gradients should be identical for two cases (given mask array)
@@ -134,7 +134,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
                 }
                 net.setLabels(labels2);
                 net.computeGradientAndScore();
-                double score2a = net.score();
+                double score2a = net.getScore();
                 Gradient g2a = net.gradient();
                 assertEquals(score2, score2a, 1e-6);
                 for (String s : g2map.keySet()) {
@@ -196,7 +196,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
             net.setInput(in1);
             net.setLabels(labels1);
             net.computeGradientAndScore();
-            double score1 = net.score();
+            double score1 = net.getScore();
             Gradient g1 = net.gradient();
             Map<String, INDArray> map1 = g1.gradientForVariable();
             for (String s : map1.keySet()) {
@@ -207,7 +207,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
             net.setLabels(labels2);
             net.setLayerMaskArrays(inputMask, null);
             net.computeGradientAndScore();
-            double score2 = net.score();
+            double score2 = net.getScore();
             Gradient g2 = net.gradient();
 
             net.setInput(in2);
@@ -240,7 +240,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
                 net.setInput(in2);
                 net.setLayerMaskArrays(inputMask, null);
                 net.computeGradientAndScore();
-                double score2a = net.score();
+                double score2a = net.getScore();
                 Gradient g2a = net.gradient();
                 assertEquals(score2, score2a, 1e-12);
                 for (String s : g2.gradientForVariable().keySet()) {
@@ -327,7 +327,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
                         mln.setLabels(labels);
 
                         mln.computeGradientAndScore();
-                        double score = mln.score();
+                        double score = mln.getScore();
 
                         assertEquals(expScore, score, 0.1, msg);
                     }
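Note: the TestVariableLengthTS hunks all assert one invariant: with a label mask set, values at masked time steps must influence neither the score nor the gradients. The comparison pattern, sketched against the renamed getters (calling this twice with labels that differ only at masked steps should return scores equal to within the tolerances used above):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class MaskInvarianceSketch {
        // Score one input under a label mask; masked steps must not matter.
        static double maskedScore(MultiLayerNetwork net, INDArray in, INDArray labels,
                                  INDArray labelMask) {
            net.setInput(in);
            net.setLabels(labels);
            net.setLayerMaskArrays(null, labelMask); // mask the labels only
            net.computeGradientAndScore();
            return net.getScore();
        }
    }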
@@ -77,7 +77,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
                 MultiLayerNetwork net2GradUpd = new MultiLayerNetwork(conf.clone());
                 net2GradUpd.init();
 
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 
                 INDArray f = Nd4j.rand(minibatch, nIn);
                 INDArray l = Nd4j.create(minibatch, nOut);
@@ -109,17 +109,17 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 
                 //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
                 // on the original network
-                net2GradUpd.params().subi(g.gradient());
+                net2GradUpd.getModelParams().subi(g.gradient());
 
                 net1GradCalc.fit(f, l);
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 
 
                 //=============================
                 if (!(u instanceof Sgd)) {
                     net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
                 }
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
                 assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
                                 net2GradUpd.getUpdater().getStateViewArray());
 
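Note: these hunks rely on getModelParams() returning the live flattened view rather than a copy: an in-place subi of an updater-processed gradient must then be equivalent to one fit step. Sketched (illustrative helper; g is assumed to already include learning rate and updater effects, as in the test):

    import org.deeplearning4j.nn.gradient.Gradient;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    class ManualGradientStepSketch {
        // Apply an already-computed, updater-processed gradient by hand.
        // Only valid if getModelParams() is the live view, so the in-place
        // subtraction mutates the network's actual weights.
        static void applyGradient(MultiLayerNetwork net, Gradient g) {
            net.getModelParams().subi(g.gradient());
        }
    }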
@@ -130,7 +130,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
                 for (int i = 0; i < 100; i++) {
                     net1GradCalc.fit(f, l);
                     net2GradUpd.fit(f, l);
-                    assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                    assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
                 }
             }
         }
@@ -169,7 +169,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
                 ComputationGraph net2GradUpd = new ComputationGraph(conf.clone());
                 net2GradUpd.init();
 
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 
                 INDArray f = Nd4j.rand(minibatch, nIn);
                 INDArray l = Nd4j.create(minibatch, nOut);
@@ -201,16 +201,16 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 
                 //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
                 // on the original network
-                net2GradUpd.params().subi(g.gradient());
+                net2GradUpd.getModelParams().subi(g.gradient());
 
                 net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
 
                 //=============================
                 if (!(u instanceof Sgd)) {
                     net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
                 }
-                assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
                 assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
                                 net2GradUpd.getUpdater().getStateViewArray());
 
@@ -222,7 +222,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
                 for (int i = 0; i < 100; i++) {
                     net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
                     net2GradUpd.fit(new INDArray[] {f}, new INDArray[] {l});
-                    assertEquals(net1GradCalc.params(), net2GradUpd.params());
+                    assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
                 }
             }
         }
@@ -25,7 +25,6 @@ import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
 import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
@@ -94,7 +93,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 
         ComputationGraph modelToFineTune = new ComputationGraph(expectedConf);
         modelToFineTune.init();
-        modelToFineTune.setParams(expectedModel.params());
+        modelToFineTune.setParams(expectedModel.getModelParams());
         //model after applying changes with transfer learning
         ComputationGraph modelNow =
                         new TransferLearning.GraphBuilder(modelToFineTune)
@@ -108,8 +107,8 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         //Check params after fit
         modelNow.fit(randomData);
         expectedModel.fit(randomData);
-        assertEquals(modelNow.score(), expectedModel.score(), 1e-8);
-        assertEquals(modelNow.params(), expectedModel.params());
+        assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-8);
+        assertEquals(modelNow.getModelParams(), expectedModel.getModelParams());
     }
 
     @Test
@@ -139,9 +138,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
                         //.setOutputs("layer3")
                         .build();
 
-        BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").getLayerConfiguration());
-        BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").getLayerConfiguration());
-        BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").getLayerConfiguration());
+        BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
+        BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
+        BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
         assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
         assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
         assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
@@ -161,22 +160,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         modelExpectedArch.init();
 
         //modelNow should have the same architecture as modelExpectedArch
-        assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(),
-                        modelNow.getLayer("layer0").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(),
-                        modelNow.getLayer("layer1").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(),
-                        modelNow.getLayer("layer2").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(),
-                        modelNow.getLayer("layer3").params().shape());
+        assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
+                        modelNow.getLayer("layer0").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
+                        modelNow.getLayer("layer1").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
+                        modelNow.getLayer("layer2").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
+                        modelNow.getLayer("layer3").getParams().shape());
 
-        modelNow.setParams(modelExpectedArch.params());
+        modelNow.setParams(modelExpectedArch.getModelParams());
         //fit should give the same results
         modelExpectedArch.fit(randomData);
         modelNow.fit(randomData);
-        assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8);
-        assertEquals(modelExpectedArch.params(), modelNow.params());
+        assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
+        assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
     }
 
     @Test
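Note: alongside the model-level rename, the per-layer accessor changes from layer.params() to layer.getParams(). Comparing two models layer by layer then looks like this sketch (the helper is illustrative; layer names would be e.g. "layer0" as in the test above):

    import java.util.Arrays;
    import org.deeplearning4j.nn.graph.ComputationGraph;

    class LayerShapeCompareSketch {
        // True if the named layer has identically shaped parameters in both graphs.
        static boolean sameLayerShape(ComputationGraph a, ComputationGraph b, String layer) {
            long[] shapeA = a.getLayer(layer).getParams().shape();
            long[] shapeB = b.getLayer(layer).getParams().shape();
            return Arrays.equals(shapeA, shapeB);
        }
    }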
@@ -227,22 +226,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         modelExpectedArch.init();
 
         //modelNow should have the same architecture as modelExpectedArch
-        assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(),
-                        modelNow.getLayer("layer0").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(),
-                        modelNow.getLayer("layer1").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(),
-                        modelNow.getLayer("layer2").params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(),
-                        modelNow.getLayer("layer3").params().shape());
+        assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
+                        modelNow.getLayer("layer0").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
+                        modelNow.getLayer("layer1").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
+                        modelNow.getLayer("layer2").getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
+                        modelNow.getLayer("layer3").getParams().shape());
 
-        modelNow.setParams(modelExpectedArch.params());
+        modelNow.setParams(modelExpectedArch.getModelParams());
         //fit should give the same results
         modelExpectedArch.fit(randomData);
         modelNow.fit(randomData);
-        assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8);
-        assertEquals(modelExpectedArch.params(), modelNow.params());
+        assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
+        assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
     }
 
     @Test
@@ -385,14 +384,14 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 
         assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());
 
-        modelNow.setParams(modelExpectedArch.params());
+        modelNow.setParams(modelExpectedArch.getModelParams());
         int i = 0;
         while (i < 5) {
             modelExpectedArch.fit(randomData);
             modelNow.fit(randomData);
             i++;
         }
-        assertEquals(modelExpectedArch.params(), modelNow.params());
+        assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
 
     }
 
@@ -26,10 +26,9 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.graph.MergeVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -99,7 +98,7 @@ public class TransferLearningComplex extends BaseDL4JTest {
             }
 
             //Also check config:
-            BaseLayer bl = ((BaseLayer) l.getLayerConfiguration());
+            BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
             assertEquals(new Adam(2e-2), bl.getIUpdater());
             assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
         }
@@ -154,8 +153,8 @@ public class TransferLearningComplex extends BaseDL4JTest {
                                         .setOutputs("outRight").build();
         ComputationGraph modelOther = new ComputationGraph(otherConf);
         modelOther.init();
-        modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params());
-        modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params());
+        modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").getParams());
+        modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").getParams());
 
         modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
         ComputationGraph modelNow =
@ -179,11 +178,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
 | 
				
			|||||||
            assertEquals(otherRandData.getFeatures(0),
 | 
					            assertEquals(otherRandData.getFeatures(0),
 | 
				
			||||||
                            modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
 | 
					                            modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
 | 
					            assertEquals(modelOther.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());
 | 
				
			||||||
            assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
 | 
					            assertEquals(modelOther.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
 | 
					            assertEquals(modelOther.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());
 | 
				
			||||||
            assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
 | 
					            assertEquals(modelOther.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
 | 
				
			||||||
            n++;
 | 
					            n++;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -237,11 +236,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
 | 
				
			|||||||
            assertEquals(otherRandData.getFeatures(0),
 | 
					            assertEquals(otherRandData.getFeatures(0),
 | 
				
			||||||
                            modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
 | 
					                            modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
 | 
					            assertEquals(modelToTune.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
 | 
					            assertEquals(modelToTune.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());
 | 
					            assertEquals(modelToTune.getLayer("outCentre").getParams(), modelNow.getLayer("outCentre").getParams());
 | 
				
			||||||
            n++;
 | 
					            n++;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
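The hunks above are part of a mechanical rename: the per-layer accessor params() becomes getParams(). A minimal sketch of the new call, assuming two initialised ComputationGraph instances that share a layer name (the class and method names below are illustrative, not part of the commit):

    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class ParamCopySketch {
        // Copy one named layer's parameters from `source` to `target`,
        // mirroring the setParams(...getParams()) pattern in the test above.
        static void copyLayer(ComputationGraph source, ComputationGraph target, String layerName) {
            INDArray params = source.getLayer(layerName).getParams(); // was .params()
            target.getLayer(layerName).setParams(params);
        }
    }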
@@ -178,25 +178,25 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
         TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
         MultiDataSet featurizedDataSet = helper.featurize(origData);
 
-        assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
+        assertEquals(modelIdentical.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());
         modelIdentical.fit(origData);
         helper.fitFeaturized(featurizedDataSet);
 
-        assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params());
-        assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params());
-        assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params());
-        assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params());
-        assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params());
+        assertEquals(modelIdentical.getLayer("denseCentre0").getParams(), modelToTune.getLayer("denseCentre0").getParams());
+        assertEquals(modelIdentical.getLayer("denseCentre1").getParams(), modelToTune.getLayer("denseCentre1").getParams());
+        assertEquals(modelIdentical.getLayer("denseCentre2").getParams(), modelToTune.getLayer("denseCentre2").getParams());
+        assertEquals(modelIdentical.getLayer("denseCentre3").getParams(), modelToTune.getLayer("denseCentre3").getParams());
+        assertEquals(modelIdentical.getLayer("outCentre").getParams(), modelToTune.getLayer("outCentre").getParams());
         assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(),
                         modelToTune.getLayer("denseRight").getNetConfiguration().toJson());
-        assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params());
+        assertEquals(modelIdentical.getLayer("denseRight").getParams(), modelToTune.getLayer("denseRight").getParams());
         assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(),
                         modelToTune.getLayer("denseRight0").getNetConfiguration().toJson());
         //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params());
-        assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params());
-        assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
-        assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params());
-        assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params());
+        assertEquals(modelIdentical.getLayer("denseRight1").getParams(), modelToTune.getLayer("denseRight1").getParams());
+        assertEquals(modelIdentical.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
+        assertEquals(modelIdentical.getLayer("denseLeft0").getParams(), modelToTune.getLayer("denseLeft0").getParams());
+        assertEquals(modelIdentical.getLayer("outLeft").getParams(), modelToTune.getLayer("outLeft").getParams());
 
 //        log.info(modelIdentical.summary());
 //        log.info(helper.unfrozenGraph().summary());
@@ -230,7 +230,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
         TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1);
 
         INDArray paramsLastTwoLayers =
-                        Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
+                        Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
         MultiLayerNetwork notFrozen = new MultiLayerNetwork(
             (NeuralNetConfiguration) overallConf.clone().list()
                             .layer(0, new Builder().nIn(2).nOut(3).build())
@@ -248,9 +248,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
             modelNow.fit(randomData);
         }
 
-        INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
-                        notFrozen.params());
-        INDArray act = modelNow.params();
+        INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
+                        notFrozen.getModelParams());
+        INDArray act = modelNow.getModelParams();
         assertEquals(expected, act);
     }
 }
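The helper test also swaps the network-level accessor: params() on a whole net becomes getModelParams(), while the per-layer call stays getParams(). A sketch of the flattened-parameter comparison used above, assuming two initialised MultiLayerNetwork instances (illustrative names):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class FlattenedParamsSketch {
        // Concatenate the frozen layer's parameter view with the trainable
        // remainder, as the expected-vs-actual hstack checks do above.
        static INDArray concat(MultiLayerNetwork frozenPart, MultiLayerNetwork notFrozen) {
            return Nd4j.hstack(frozenPart.getLayer(0).getParams(), // per-layer view, was .params()
                               notFrozen.getModelParams());        // whole net, was notFrozen.params()
        }
    }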
@@ -91,7 +91,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
                         .build();
 
         for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
-            BaseLayer bl = ((BaseLayer) l.getLayerConfiguration());
+            BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
             assertEquals(new RmsProp(0.5), bl.getIUpdater());
         }
 
@@ -107,9 +107,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
                                                         .build())
                         .build());
         expectedModel.init();
-        expectedModel.setParams(modelToFineTune.params().dup());
+        expectedModel.setParams(modelToFineTune.getModelParams().dup());
 
-        assertEquals(expectedModel.params(), modelNow.params());
+        assertEquals(expectedModel.getModelParams(), modelNow.getModelParams());
 
         //Check json
         NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration();
@@ -119,9 +119,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         modelNow.fit(randomData);
         expectedModel.fit(randomData);
 
-        assertEquals(modelNow.score(), expectedModel.score(), 1e-6);
-        INDArray pExp = expectedModel.params();
-        INDArray pNow = modelNow.params();
+        assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-6);
+        INDArray pExp = expectedModel.getModelParams();
+        INDArray pNow = modelNow.getModelParams();
         assertEquals(pExp, pNow);
     }
 
@@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         //Will fail - expected because of dist and weight init changes
         //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
 
-        BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer());
-        BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer());
-        BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer());
+        BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
+        BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
+        BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
         assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
         try {
             assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
@@ -173,18 +173,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals(bl3.getWeightInitFn(), new WeightInitXavier());
 
         //modelNow should have the same architecture as modelExpectedArch
-        assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
+        assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
 
-        modelNow.setParams(modelExpectedArch.params());
+        modelNow.setParams(modelExpectedArch.getModelParams());
         //fit should give the same results
         modelExpectedArch.fit(randomData);
         modelNow.fit(randomData);
-        assertEquals(modelExpectedArch.score(), modelNow.score(), 0.000001);
-        assertEquals(modelExpectedArch.params(), modelNow.params());
+        assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 0.000001);
+        assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
     }
 
 
@@ -227,20 +227,20 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         modelExpectedArch.init();
 
         //modelNow should have the same architecture as modelExpectedArch
-        assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
+        assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
 
-        modelNow.setParams(modelExpectedArch.params());
+        modelNow.setParams(modelExpectedArch.getModelParams());
         //fit should give the same results
         modelExpectedArch.fit(randomData);
         modelNow.fit(randomData);
-        double scoreExpected = modelExpectedArch.score();
-        double scoreActual = modelNow.score();
+        double scoreExpected = modelExpectedArch.getScore();
+        double scoreActual = modelNow.getScore();
         assertEquals(scoreExpected, scoreActual, 1e-4);
-        assertEquals(modelExpectedArch.params(), modelNow.params());
+        assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
     }
 
     @Test
@@ -370,14 +370,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(),
                         modelNow.getNetConfiguration().getConf(5).toJson());
 
-        assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+        assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
         //subsampling has no params
         //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(4).params().shape(), modelNow.getLayer(4).params().shape());
-        assertArrayEquals(modelExpectedArch.getLayer(5).params().shape(), modelNow.getLayer(5).params().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(4).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+        assertArrayEquals(modelExpectedArch.getLayer(5).getParams().shape(), modelNow.getLayer(5).getParams().shape());
 
     }
 
@@ -449,23 +449,23 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
                         .inputType(InputType.convolutionalFlat(12, 12, 20)).build());
         notFrozen.init();
 
-        assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+        assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
         //subsampling has no params
         //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-        assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape());
-        modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params());
+        assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+        modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
         //subsampling has no params
         //assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape());
-        assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape());
-        modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params());
-        assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape());
-        modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params());
-        assertArrayEquals(notFrozen.getLayer(4).params().shape(), modelNow.getLayer(6).params().shape());
-        modelNow.getLayer(6).setParams(notFrozen.getLayer(4).params());
-        assertArrayEquals(notFrozen.getLayer(5).params().shape(), modelNow.getLayer(7).params().shape());
-        modelNow.getLayer(7).setParams(notFrozen.getLayer(5).params());
-        assertArrayEquals(notFrozen.getLayer(6).params().shape(), modelNow.getLayer(8).params().shape());
-        modelNow.getLayer(8).setParams(notFrozen.getLayer(6).params());
+        assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+        modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
+        assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
+        modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());
+        assertArrayEquals(notFrozen.getLayer(4).getParams().shape(), modelNow.getLayer(6).getParams().shape());
+        modelNow.getLayer(6).setParams(notFrozen.getLayer(4).getParams());
+        assertArrayEquals(notFrozen.getLayer(5).getParams().shape(), modelNow.getLayer(7).getParams().shape());
+        modelNow.getLayer(7).setParams(notFrozen.getLayer(5).getParams());
+        assertArrayEquals(notFrozen.getLayer(6).getParams().shape(), modelNow.getLayer(8).getParams().shape());
+        modelNow.getLayer(8).setParams(notFrozen.getLayer(6).getParams());
 
         int i = 0;
         while (i < 3) {
@@ -474,8 +474,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
             i++;
         }
 
-        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params());
-        assertEquals(expectedParams, modelNow.params());
+        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
+        assertEquals(expectedParams, modelNow.getModelParams());
     }
 
 
@@ -503,13 +503,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
 
 
         //Check original net isn't modified:
-        BaseLayer l0 = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+        BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
         assertEquals(new Adam(1e-4), l0.getIUpdater());
         assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
         assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
         assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 
-        BaseLayer l1 = (BaseLayer) net.getLayer(1).getLayerConfiguration();
+        BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
         assertEquals(new Adam(1e-4), l1.getIUpdater());
         assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
         assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@@ -518,13 +518,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals(BackpropType.Standard, conf.getBackpropType());
 
         //Check new net has only the appropriate things modified (i.e., LR)
-        l0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration();
+        l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
         assertEquals(new Adam(2e-2), l0.getIUpdater());
         assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
         assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
         assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
 
-        l1 = (BaseLayer) net2.getLayer(1).getLayerConfiguration();
+        l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
         assertEquals(new Adam(2e-2), l1.getIUpdater());
         assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
         assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@@ -586,17 +586,17 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
                         .build());
         notFrozen.init();
 
-        assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
+        assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
         //subsampling has no params
         //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
-        assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape());
-        modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params());
-        assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape());
-        modelNow.getLayer(3).setParams(notFrozen.getLayer(1).params());
-        assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape());
-        modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params());
-        assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape());
-        modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params());
+        assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
+        modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
+        assertArrayEquals(notFrozen.getLayer(1).getParams().shape(), modelNow.getLayer(3).getParams().shape());
+        modelNow.getLayer(3).setParams(notFrozen.getLayer(1).getParams());
+        assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
+        modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
+        assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
+        modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());
 
         int i = 0;
         while (i < 3) {
@@ -605,8 +605,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
             i++;
         }
 
-        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params());
-        assertEquals(expectedParams, modelNow.params());
+        INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
+        assertEquals(expectedParams, modelNow.getModelParams());
     }
 
     @Test
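TransferLearningMLNTest exercises three renames at once: score() becomes getScore(), network-level params() becomes getModelParams(), and the layer-configuration class BaseLayer becomes BaseLayerConfiguration. A compact sketch combining them, assuming an initialised and already-fit MultiLayerNetwork (illustrative names, not part of the commit):

    import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.learning.config.IUpdater;

    public class RenameSketch {
        // was: ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getIUpdater()
        static IUpdater firstLayerUpdater(MultiLayerNetwork net) {
            BaseLayerConfiguration bl =
                    (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
            return bl.getIUpdater();
        }

        // was: net.score()
        static double lastScore(MultiLayerNetwork net) {
            return net.getScore();
        }
    }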
@@ -99,7 +99,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -144,7 +144,7 @@ public class TestUpdaters extends BaseDL4JTest {
                 msdx.put(key, msdxTmp);
                 count++;
             }
-            assertEquals(rho, ((AdaDelta)layer.layerConf().getIUpdater()).getRho(), 1e-4);
+            assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4);
         }
 
         assertEquals(4, count);
@@ -165,7 +165,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -185,7 +185,7 @@ public class TestUpdaters extends BaseDL4JTest {
             assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
             count++;
         }
-        assertEquals(lr, ((AdaGrad)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+        assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
         assertEquals(2, count);
     }
 
@@ -209,7 +209,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -245,8 +245,8 @@ public class TestUpdaters extends BaseDL4JTest {
             count++;
         }
 
-        assertEquals(beta1, ((Adam)layer.layerConf().getIUpdater()).getBeta1(), 1e-4);
-        assertEquals(beta2, ((Adam)layer.layerConf().getIUpdater()).getBeta2(), 1e-4);
+        assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
+        assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
         assertEquals(2, count);
     }
 
@@ -273,7 +273,7 @@ public class TestUpdaters extends BaseDL4JTest {
         layer.setBackpropGradientsViewArray(gradients);
 
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -362,7 +362,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -398,8 +398,8 @@ public class TestUpdaters extends BaseDL4JTest {
             count++;
         }
 
-        assertEquals(beta1, ((AdaMax)layer.layerConf().getIUpdater()).getBeta1(), 1e-4);
-        assertEquals(beta2, ((AdaMax)layer.layerConf().getIUpdater()).getBeta2(), 1e-4);
+        assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
+        assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
         assertEquals(2, count);
     }
 
@@ -418,7 +418,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -443,7 +443,7 @@ public class TestUpdaters extends BaseDL4JTest {
             count++;
         }
 
-        assertEquals(mu, ((Nesterovs)layer.layerConf().getIUpdater()).getMomentum(), 1e-4);
+        assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4);
         assertEquals(2, count);
     }
 
@@ -465,7 +465,7 @@ public class TestUpdaters extends BaseDL4JTest {
         BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         Updater updater = UpdaterCreator.getUpdater(layer);
-        int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams);
+        int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
         INDArray updaterState = Nd4j.create(1, updaterStateSize);
         updater.setStateViewArray(layer, updaterState, true);
 
@@ -495,7 +495,7 @@ public class TestUpdaters extends BaseDL4JTest {
             assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
             lastG.put(key, lastGTmp);
         }
-        assertEquals(rmsDecay, ((RmsProp)layer.layerConf().getIUpdater()).getRmsDecay(), 1e-4);
+        assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4);
     }
 
     @Test
@@ -527,7 +527,7 @@ public class TestUpdaters extends BaseDL4JTest {
             gradExpected = val.mul(lr);
             assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
         }
-        assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+        assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
     }
 
 
@@ -769,7 +769,7 @@ public class TestUpdaters extends BaseDL4JTest {
             gradExpected = val.mul(lr);
             assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
         }
-        assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+        assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
 
 
         //Test with pretrain == false
@@ -797,7 +797,7 @@ public class TestUpdaters extends BaseDL4JTest {
         layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(gradients);
         updater = UpdaterCreator.getUpdater(layer);
-        assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4);
+        assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
     }
 
     @Test
@@ -858,11 +858,11 @@ public class TestUpdaters extends BaseDL4JTest {
             //Check first updater block:
             UpdaterBlock ub0 = blocks.get(0);
             assertEquals(3, ub0.getLayersAndVariablesInBlock().size());
-            assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+            assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(0).getParamName());
-            assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+            assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.BIAS_KEY, ub0.getLayersAndVariablesInBlock().get(1).getParamName());
-            assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getConfig().getLayerName());
+            assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(2).getParamName());
 
             int nParams0 = 10 * 10 + 10 + 10 * 10;
@@ -875,7 +875,7 @@ public class TestUpdaters extends BaseDL4JTest {
             //Check second updater block:
             UpdaterBlock ub1 = blocks.get(1);
             assertEquals(1, ub1.getLayersAndVariablesInBlock().size());
-            assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+            assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.BIAS_KEY, ub1.getLayersAndVariablesInBlock().get(0).getParamName());
 
             int nParams1 = 10;
@@ -888,9 +888,9 @@ public class TestUpdaters extends BaseDL4JTest {
             //Check third updater block:
             UpdaterBlock ub2 = blocks.get(2);
             assertEquals(2, ub2.getLayersAndVariablesInBlock().size());
-            assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+            assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub2.getLayersAndVariablesInBlock().get(0).getParamName());
-            assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+            assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.BIAS_KEY, ub2.getLayersAndVariablesInBlock().get(1).getParamName());
 
             int nParams2 = 10 * 10 + 10;
@@ -903,9 +903,9 @@ public class TestUpdaters extends BaseDL4JTest {
             //Check fourth updater block:
             UpdaterBlock ub3 = blocks.get(3);
             assertEquals(2, ub3.getLayersAndVariablesInBlock().size());
-            assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+            assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName());
-            assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+            assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName());
 
             int nParams3 = 10 * 10 + 10;
@@ -918,9 +918,9 @@ public class TestUpdaters extends BaseDL4JTest {
             //Check fifth updater black
             UpdaterBlock ub4 = blocks.get(4);
             assertEquals(2, ub4.getLayersAndVariablesInBlock().size());
-            assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName());
+            assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName());
-            assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName());
+            assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
             assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName());
 
             int nParams4 = 10 * 10 + 10;
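Two accessor renames run through the whole of TestUpdaters: layerConf() becomes getTypedLayerConfiguration(), and getConfig() becomes getTrainingConfig(). A sketch of both on a Layer instance (hypothetical helper, not from the commit):

    import org.deeplearning4j.nn.api.Layer;

    public class LayerAccessSketch {
        // was: layer.layerConf().getIUpdater().stateSize(numParams)
        static long updaterStateSize(Layer layer, long numParams) {
            return layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
        }

        // was: layer.getConfig().getLayerName()
        static String layerName(Layer layer) {
            return layer.getTrainingConfig().getLayerName();
        }
    }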
@@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater.custom;

 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest {
                         .build();

         //First: Check updater config
-        assertTrue(((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
-        assertTrue(((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
-        assertTrue(((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
-        assertTrue(((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);
+        assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
+        assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
+        assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
+        assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);

-        CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater();
-        CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater();
+        CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater();
+        CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater();
         assertEquals(lr, u0_0.getLearningRate(), 1e-6);
         assertEquals(lr, u0_1.getLearningRate(), 1e-6);

-        Sgd u1_0 = (Sgd) ((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater();
-        Sgd u1_1 = (Sgd) ((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater();
+        Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater();
+        Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater();
         assertEquals(lr, u1_0.getLearningRate(), 1e-6);
         assertEquals(lr, u1_1.getLearningRate(), 1e-6);

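The cast target changes because the layer-configuration base class was renamed from BaseLayer to BaseLayerConfiguration, which disambiguates it from the runtime layer class of the same name. A sketch of the updated extraction, assuming the merged NeuralNetConfiguration type with the getConf(int) accessor used above (the helper name is an assumption):

    // Pulls the per-layer updater out of a built configuration; the cast
    // to BaseLayerConfiguration follows the rename in this commit.
    private static IUpdater updaterAt(NeuralNetConfiguration conf, int idx) {
        return ((BaseLayerConfiguration) conf.getConf(idx).getLayer()).getIUpdater();
    }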
@@ -81,7 +81,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());

         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());

         assertEquals(1.0, step, 1e-3);
     }
@@ -97,11 +97,11 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();

         BackTrackLineSearch lineSearch =
                         new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());

         assertEquals(1.0, step, 1e-3);
     }
@@ -118,18 +118,18 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();
         INDArray origGradient = layer.gradient().gradient().dup();

         NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction();
         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
-        INDArray currParams = layer.params();
+        double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
+        INDArray currParams = layer.getModelParams();
         sf.step(currParams, origGradient, step);
         layer.setParamsTable(currParams);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());

-        score2 = layer.score();
+        score2 = layer.getScore();

         assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2);

@@ -146,19 +146,19 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.setLabels(irisData.getLabels());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score1 = layer.score();
+        score1 = layer.getScore();
         INDArray origGradient = layer.gradient().gradient().dup();

         DefaultStepFunction sf = new DefaultStepFunction();
         BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
-        double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(),
+        double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(),
                         layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable());

-        INDArray currParams = layer.params();
+        INDArray currParams = layer.getModelParams();
         sf.step(currParams, origGradient, step);
         layer.setParamsTable(currParams);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        score2 = layer.score();
+        score2 = layer.getScore();

         assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2);
     }
@@ -190,12 +190,12 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);
         for( int i=0; i<100; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < oldScore);
     }

@@ -208,13 +208,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double firstScore = network.score(data);

         for( int i=0; i<5; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < firstScore);

     }
@@ -227,13 +227,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
         TrainingListener listener = new ScoreIterationListener(10);
-        network.setListeners(Collections.singletonList(listener));
+        network.addTrainingListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);

         for( int i=0; i<5; i++ ) {
             network.fit(data.getFeatures(), data.getLabels());
         }
-        double score = network.score();
+        double score = network.getScore();
         assertTrue(score < oldScore);

     }
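Three renames recur in this file and throughout the listener tests below: score() becomes getScore(), params() becomes getModelParams(), and setListeners(...) becomes addTrainingListeners(...). Note that network.score(DataSet) keeps its name; only the no-argument score accessor is renamed. A minimal sketch of the migrated pattern (the helper itself is illustrative, not part of the commit):

    // Fit once, then read the flattened parameters and the last score
    // through the accessors this commit introduces.
    private static double fitAndScore(MultiLayerNetwork network, DataSet data) {
        network.addTrainingListeners(new ScoreIterationListener(10)); // was setListeners(...)
        network.fit(data.getFeatures(), data.getLabels());
        INDArray params = network.getModelParams();                   // was params()
        System.out.println("nParams: " + params.length());
        return network.getScore();                                    // was score()
    }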
@@ -28,6 +28,7 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.*;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -211,38 +212,38 @@ public class TestOptimizers extends BaseDL4JTest {
             System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= "
                             + nDimensions);

-        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter)
+        LayerConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter)
                         .updater(new Sgd(1e-2))
-                        .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
-        conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
+                        .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build().getFlattenedLayerConfigurations().get(0);
+        conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

         Random rng = new DefaultRandom(12345L);
         org.nd4j.linalg.api.rng.distribution.Distribution dist =
                         new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10);
         IModel m = new SphereFunctionModel(nDimensions, dist, conf);
         m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        double scoreBefore = m.score();
+        double scoreBefore = m.getScore();
         assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore));
         if (PRINT_OPT_RESULTS) {
             System.out.println("Before:");
             System.out.println(scoreBefore);
-            System.out.println(m.params());
+            System.out.println(m.getModelParams());
         }

-        ConvexOptimizer opt = getOptimizer(oa, conf, m);
+        ConvexOptimizer opt = getOptimizer(oa, conf.getNetConfiguration(), m);

         opt.setupSearchState(m.gradientAndScore());
         for( int i=0; i<100; i++ ) {
             opt.optimize(LayerWorkspaceMgr.noWorkspaces());
         }
         m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-        double scoreAfter = m.score();
+        double scoreAfter = m.getScore();

         assertTrue(!Double.isNaN(scoreAfter) && !Double.isInfinite(scoreAfter));
         if (PRINT_OPT_RESULTS) {
             System.out.println("After:");
             System.out.println(scoreAfter);
-            System.out.println(m.params());
+            System.out.println(m.getModelParams());
         }

         //Expected behaviour after optimization:
@@ -279,17 +280,17 @@ public class TestOptimizers extends BaseDL4JTest {
                             .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
             conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-            IModel m = new SphereFunctionModel(100, dist, conf);
+            IModel m = new SphereFunctionModel(100, dist, conf.getFlattenedLayerConfigurations().get(0));
             if (i == 0) {
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[0] = m.score(); //Before optimization
+                scores[0] = m.getScore(); //Before optimization
             } else {
                 ConvexOptimizer opt = getOptimizer(oa, conf, m);
                 for( int j=0; j<100; j++ ) {
                     opt.optimize(LayerWorkspaceMgr.noWorkspaces());
                 }
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[i] = m.score();
+                scores[i] = m.getScore();
                 assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
             }
         }
@@ -316,7 +317,7 @@ public class TestOptimizers extends BaseDL4JTest {
         private static final long serialVersionUID = -6963606137417355405L;

         private SphereFunctionModel(int nParams, org.nd4j.linalg.api.rng.distribution.Distribution distribution,
-                        NeuralNetConfiguration conf) {
+                        LayerConfiguration conf) {
             super(distribution.sample(new int[] {1, nParams}), conf);
         }

@@ -437,7 +438,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -499,17 +500,17 @@ public class TestOptimizers extends BaseDL4JTest {
                             .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
             conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-            IModel m = new RastriginFunctionModel(10, conf);
+            IModel m = new RastriginFunctionModel(10, conf.getFlattenedLayerConfigurations().get(0));
             int nParams = (int)m.numParams();
             if (i == 0) {
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[0] = m.score(); //Before optimization
+                scores[0] = m.getScore(); //Before optimization
             } else {
                 ConvexOptimizer opt = getOptimizer(oa, conf, m);
                 opt.getUpdater().setStateViewArray((Layer) m, Nd4j.create(new int[] {1, nParams}, 'c'), true);
                 opt.optimize(LayerWorkspaceMgr.noWorkspaces());
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[i] = m.score();
+                scores[i] = m.getScore();
                 assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
             }
         }
@@ -540,7 +541,7 @@ public class TestOptimizers extends BaseDL4JTest {
     private static class RastriginFunctionModel extends SimpleOptimizableModel {
         private static final long serialVersionUID = -1772954508787487941L;

-        private RastriginFunctionModel(int nDimensions, NeuralNetConfiguration conf) {
+        private RastriginFunctionModel(int nDimensions, LayerConfiguration conf) {
             super(initParams(nDimensions), conf);
         }

@@ -710,7 +711,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -768,15 +769,15 @@ public class TestOptimizers extends BaseDL4JTest {
                             .build();
             conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here

-            IModel m = new RosenbrockFunctionModel(100, conf);
+            IModel m = new RosenbrockFunctionModel(100, conf.getFlattenedLayerConfigurations().get(0));
             if (i == 0) {
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[0] = m.score(); //Before optimization
+                scores[0] = m.getScore(); //Before optimization
             } else {
                 ConvexOptimizer opt = getOptimizer(oa, conf, m);
                 opt.optimize(LayerWorkspaceMgr.noWorkspaces());
                 m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-                scores[i] = m.score();
+                scores[i] = m.getScore();
                 assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]), "NaN or infinite score: " + scores[i]);
             }
         }
@@ -810,7 +811,7 @@ public class TestOptimizers extends BaseDL4JTest {
     private static class RosenbrockFunctionModel extends SimpleOptimizableModel {
         private static final long serialVersionUID = -5129494342531033706L;

-        private RosenbrockFunctionModel(int nDimensions, NeuralNetConfiguration conf) {
+        private RosenbrockFunctionModel(int nDimensions, LayerConfiguration conf) {
             super(initParams(nDimensions), conf);
         }

@@ -995,7 +996,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -1029,13 +1030,31 @@ public class TestOptimizers extends BaseDL4JTest {
         private static final long serialVersionUID = 4409380971404019303L;
         protected INDArray parameters;
         protected INDArray gradientView;
-        protected final NeuralNetConfiguration conf;
+        protected final LayerConfiguration conf;
         protected Gradient gradient;
         protected double score;

+        /**
+         * @return 1d parameter vector
+         */
+        @Override
+        public INDArray getParams() {
+            throw new RuntimeException("Not implemented");
+        }
+
+        /**
+         * Get a reference to the network this layer is part of.
+         *
+         * @return
+         */
+        @Override
+        public IModel getNet() {
+           throw new RuntimeException("Not implemented");
+        }
+
         /**@param parameterInit Initial parameters. Also determines dimensionality of problem. Should be row vector.
          */
-        private SimpleOptimizableModel(INDArray parameterInit, NeuralNetConfiguration conf) {
+        private SimpleOptimizableModel(INDArray parameterInit, LayerConfiguration conf) {
             this.parameters = parameterInit.dup();
             this.gradientView = Nd4j.create(parameterInit.shape());
             this.conf = conf;
@@ -1048,17 +1067,12 @@ public class TestOptimizers extends BaseDL4JTest {
          */
         @Override
         public LayerConfiguration getLayerConfiguration() {
-            return this.conf.getFirstLayer();
+            return this.conf;
         }

         @Override
-        public void addListeners(TrainingListener... listener) {
-            // no-op
-        }
-
-        @Override
-        public TrainingConfig getConfig() {
-            return conf.getFirstLayer();
+        public ITraininableLayerConfiguration getTrainingConfig() {
+            return (BaseLayerConfiguration) conf;
         }

         /**
@@ -1092,7 +1106,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public void setListeners(TrainingListener... listeners) {
+        public void addTrainingListeners(TrainingListener... listeners) {

         }

@@ -1112,7 +1126,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public double score() {
+        public double getScore() {
             return score;
         }

@@ -1132,7 +1146,7 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public INDArray params() {
+        public INDArray getModelParams() {
             return parameters;
         }

@@ -1154,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest {
         @Override
         public Pair<Gradient, Double> gradientAndScore() {
             computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
-            return new Pair<>(gradient(), score());
+            return new Pair<>(gradient(), getScore());
         }

         @Override
@@ -1164,7 +1178,7 @@ public class TestOptimizers extends BaseDL4JTest {

         @Override
         public NeuralNetConfiguration getNetConfiguration() {
-            return conf;
+            return conf.getNetConfiguration();
         }

         @Override
@@ -1225,12 +1239,12 @@ public class TestOptimizers extends BaseDL4JTest {
         }

         @Override
-        public Collection<TrainingListener> getListeners() {
+        public Collection<TrainingListener> getTrainingListeners() {
             return null;
         }

         @Override
-        public void setListeners(Collection<TrainingListener> listeners) {
+        public void addTrainingListeners(Collection<TrainingListener> listeners) {
             throw new UnsupportedOperationException();
         }

@@ -1310,4 +1324,6 @@ public class TestOptimizers extends BaseDL4JTest {
         public void close(){
         }
     }
+
+
 }
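Most of the churn above follows from SimpleOptimizableModel now wrapping a single LayerConfiguration instead of a whole NeuralNetConfiguration: call sites build the network configuration as before, unwrap its one layer via getFlattenedLayerConfigurations().get(0), and the model hands the network-level view back out through conf.getNetConfiguration(). A condensed sketch of the new construction path, assembled only from calls that appear in this diff (variable names follow the test):

    // Build a one-layer network config, then pass its single flattened
    // LayerConfiguration to the toy optimization model (per this commit).
    LayerConfiguration layerConf = NeuralNetConfiguration.builder()
            .updater(new Sgd(1e-2))
            .layer(new DenseLayer.Builder().nIn(1).nOut(1).build())
            .build()
            .getFlattenedLayerConfigurations().get(0);
    layerConf.addVariable("W"); // normally done by the ParamInitializer
    IModel m = new SphereFunctionModel(nDimensions, dist, layerConf);
    ConvexOptimizer opt = getOptimizer(oa, layerConf.getNetConfiguration(), m);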
@@ -76,7 +76,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .keepAll()
                 .saveEveryNEpochs(2)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<10; i++ ){
             net.fit(iter);
@@ -125,7 +125,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .keepLast(3)
                 .saveEveryNIterations(5)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<20; i++ ){   //40 iterations total
             net.fit(iter);
@@ -167,7 +167,7 @@ public class TestCheckpointListener extends BaseDL4JTest {

         MultiLayerNetwork netStatic2 = CheckpointListener.loadLastCheckpointMLN(f);
         assertEquals(35, netStatic2.getIterationCount());
-        assertEquals(netStatic.params(), netStatic2.params());
+        assertEquals(netStatic.getModelParams(), netStatic2.getModelParams());
     }

     @Test
@@ -182,7 +182,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .keepLast(3)
                 .saveEvery(4900, TimeUnit.MILLISECONDS)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<3; i++ ){   //10 iterations total
             net.fit(iter);
@@ -226,7 +226,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .keepLastAndEvery(3, 3)
                 .saveEveryNEpochs(2)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<20; i++ ){   //40 iterations total
             net.fit(iter);
@@ -272,7 +272,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .keepAll()
                 .saveEveryNEpochs(1)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         for(int i=0; i<3; i++ ){
             net.fit(iter);
@@ -294,7 +294,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
                 .saveEveryNEpochs(1)
                 .deleteExisting(true)
                 .build();
-        net.setListeners(l);
+        net.addTrainingListeners(l);

         net.fit(iter);

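Every hunk in this file is the same one-line substitution, so the rename applies mechanically; the new name also signals additive registration (add) rather than wholesale replacement (set), which may matter if a test relied on setListeners discarding earlier listeners. A sketch of the attach site, with the builder arguments taken from the hunks above and the checkpoint directory File f assumed from context:

    // Attach a checkpoint listener under the renamed API.
    CheckpointListener l = new CheckpointListener.Builder(f)
            .keepLast(3)
            .saveEveryNIterations(5)
            .build();
    net.addTrainingListeners(l);   // was: net.setListeners(l)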
@@ -58,7 +58,7 @@ public class TestFailureListener extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
 //                FailureTestingListener.FailureMode.OOM,
                 FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
                 new FailureTestingListener.IterationEpochTrigger(false, 10)));
@@ -84,7 +84,7 @@ public class TestFailureListener extends BaseDL4JTest {
         assertNotNull(username);
         assertFalse(username.isEmpty());

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
                 FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
                 new FailureTestingListener.Or(
                         new FailureTestingListener.IterationEpochTrigger(false, 10000),
@@ -112,7 +112,7 @@ public class TestFailureListener extends BaseDL4JTest {
         assertNotNull(hostname);
         assertFalse(hostname.isEmpty());

-        net.setListeners(new FailureTestingListener(
+        net.addTrainingListeners(new FailureTestingListener(
                 FailureTestingListener.FailureMode.ILLEGAL_STATE,
                 new FailureTestingListener.And(
                         new FailureTestingListener.HostNameTrigger(hostname),
@@ -77,17 +77,17 @@ public class TestListeners extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

-        net.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+        net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());

         for (Layer l : net.getLayers()) {
-            Collection<TrainingListener> layerListeners = l.getListeners();
+            Collection<TrainingListener> layerListeners = l.getTrainingListeners();
             assertEquals(2, layerListeners.size(), l.getClass().toString());
             TrainingListener[] lArr = layerListeners.toArray(new TrainingListener[2]);
             assertTrue(lArr[0] instanceof ScoreIterationListener);
             assertTrue(lArr[1] instanceof TestRoutingListener);
         }

-        Collection<TrainingListener> netListeners = net.getListeners();
+        Collection<TrainingListener> netListeners = net.getTrainingListeners();
         assertEquals(2, netListeners.size());
         TrainingListener[] lArr = netListeners.toArray(new TrainingListener[2]);
         assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -101,17 +101,17 @@ public class TestListeners extends BaseDL4JTest {
         ComputationGraph cg = new ComputationGraph(gConf);
         cg.init();

-        cg.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+        cg.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());

         for (Layer l : cg.getLayers()) {
-            Collection<TrainingListener> layerListeners = l.getListeners();
+            Collection<TrainingListener> layerListeners = l.getTrainingListeners();
             assertEquals(2, layerListeners.size());
             lArr = layerListeners.toArray(new TrainingListener[2]);
             assertTrue(lArr[0] instanceof ScoreIterationListener);
             assertTrue(lArr[1] instanceof TestRoutingListener);
         }

-        netListeners = cg.getListeners();
+        netListeners = cg.getTrainingListeners();
         assertEquals(2, netListeners.size());
         lArr = netListeners.toArray(new TrainingListener[2]);
         assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -180,7 +180,7 @@ public class TestListeners extends BaseDL4JTest {

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(listeners);
+        net.addTrainingListeners(listeners);

         net.fit(iter);

@@ -199,7 +199,7 @@ public class TestListeners extends BaseDL4JTest {
             listeners2.add(il2);
         }

-        net.setListeners(listeners2);
+        net.addTrainingListeners(listeners2);
         net.fit(iter);
     }

@@ -216,7 +216,7 @@ public class TestListeners extends BaseDL4JTest {
         net.init();

         TestListener tl = new TestListener();
-        net.setListeners(tl);
+        net.addTrainingListeners(tl);

         DataSetIterator irisIter = new IrisDataSetIterator(50, 150);

@@ -260,7 +260,7 @@ public class TestListeners extends BaseDL4JTest {
         tl = new TestListener();

         ComputationGraph cg = net.toComputationGraph();
-        cg.setListeners(tl);
+        cg.addTrainingListeners(tl);

         cg.fit(irisIter, 2);

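This file exercises both directions of the listener rename: listeners are attached with addTrainingListeners(...) and read back, per layer and per network or graph, with getTrainingListeners() (formerly getListeners()). A sketch of the round trip, mirroring the assertions above:

    // Attach two listeners, then verify both are visible on the network
    // through the renamed getter.
    net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());
    Collection<TrainingListener> attached = net.getTrainingListeners();
    assertEquals(2, attached.size());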
@@ -94,7 +94,7 @@ public class RandomTests extends BaseDL4JTest {
 
         // at the end of day, model params has to
         for (int i = 0; i < models.size(); i++) {
-            assertEquals(models.get(0).params(), models.get(i).params());
+            assertEquals(models.get(0).getModelParams(), models.get(i).getModelParams());
         }
     }
 
@@ -119,7 +119,7 @@ public class RandomTests extends BaseDL4JTest {
         MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
         net2.init();
 
-        assertEquals(net1.params(), net2.params());
+        assertEquals(net1.getModelParams(), net2.getModelParams());
 
         NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json);
 
@@ -127,6 +127,6 @@ public class RandomTests extends BaseDL4JTest {
         MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
         net3.init();
 
-        assertEquals(net1.params(), net3.params());
+        assertEquals(net1.getModelParams(), net3.getModelParams());
     }
 }
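RandomTests exercises the other recurring rename: the flattened parameter view moves from params() to getModelParams(). In isolation, and assuming the accessor still returns the same single flattened row vector the assertions compare:

    // Two networks initialized from the same configuration (and therefore the same
    // seed) must hold identical flattened parameter views.
    MultiLayerNetwork net1 = new MultiLayerNetwork(conf);
    MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
    net1.init();
    net2.init();
    assertEquals(net1.getModelParams(), net2.getModelParams());  // was: net1.params(), net2.params()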
@@ -63,7 +63,7 @@ public class TestSystemInfoPrintListener extends BaseDL4JTest {
 
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.setListeners(systemInfoFilePrintListener);
+        net.addTrainingListeners(systemInfoFilePrintListener);
 
         DataSetIterator iter = new IrisDataSetIterator(10, 150);
 
@@ -87,7 +87,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(net.numParams());
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -126,7 +126,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -170,7 +170,7 @@ public class RegressionTest050 extends BaseDL4JTest {
         assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
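The RegressionTest050 hunks, and the 060/071/080 variants that follow, all instantiate one template: the saved model's parameters were written as the sequence 1..numParams, so an exact linspace comparison pins down both the values and the parameter flattening order, and the updater state is checked the same way. Extracted as a sketch (net is a restored network; for Nesterovs momentum, stateSize(n) == n, which is why the state reshapes to the same 1 x numParams shape):

    int numParams = (int) net.numParams();
    // Parameters of the saved model were set to [1, 2, ..., numParams]:
    assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1, numParams),
            net.getModelParams());
    // The updater (momentum) state was saved the same way:
    int updaterSize = (int) new Nesterovs().stateSize(numParams);
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1, numParams),
            net.getUpdater().getStateViewArray());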
@@ -89,7 +89,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -132,7 +132,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -178,7 +178,7 @@ public class RegressionTest060 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -90,7 +90,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
 
         long numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -133,7 +133,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
 
         long numParams = net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -179,7 +179,7 @@ public class RegressionTest071 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
 
         long numParams = net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -94,7 +94,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertEquals(0.15, n.getLearningRate(), 1e-6);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new Nesterovs().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -143,7 +143,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -194,7 +194,7 @@ public class RegressionTest080 extends BaseDL4JTest {
         assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
 
         int numParams = (int)net.numParams();
-        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+        assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
         int updaterSize = (int) new RmsProp().stateSize(numParams);
         assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
     }
@@ -97,7 +97,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
 
             assertEquals(dt, in.dataType());
             assertEquals(dt, outExp.dataType());
-            assertEquals(dt, net.params().dataType());
+            assertEquals(dt, net.getModelParams().dataType());
             assertEquals(dt, net.getFlattenedGradients().dataType());
             assertEquals(dt, net.getUpdater().getStateViewArray().dataType());
 
@@ -109,7 +109,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
             List<INDArray> activations = net.feedForward(in);
 
             assertEquals(dt, net.getNetConfiguration().getDataType());
-            assertEquals(dt, net.params().dataType());
+            assertEquals(dt, net.getModelParams().dataType());
             assertEquals( outExp, outAct, dtype);
         }
     }
@@ -116,7 +116,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
 
             assertEquals(dtype, in.dataType());
             assertEquals(dtype, outExp.dataType());
-            assertEquals(dtype, net.params().dataType());
+            assertEquals(dtype, net.getModelParams().dataType());
             assertEquals(dtype, net.getFlattenedGradients().dataType());
             assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
 
@@ -126,7 +126,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
             assertEquals(dtype, outAct.dataType());
 
             assertEquals(dtype, net.getNetConfiguration().getDataType());
-            assertEquals(dtype, net.params().dataType());
+            assertEquals(dtype, net.getModelParams().dataType());
             boolean eq = outExp.equalsWithEps(outAct, 0.01);
             assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct);
         }
@@ -98,7 +98,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
 
             assertEquals(dtype, in.dataType());
             assertEquals(dtype, outExp.dataType());
-            assertEquals(dtype, net.params().dataType());
+            assertEquals(dtype, net.getModelParams().dataType());
             assertEquals(dtype, net.getFlattenedGradients().dataType());
             assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
 
@@ -108,7 +108,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
             assertEquals(dtype, outAct.dataType());
 
             assertEquals(dtype, net.getNetConfiguration().getDataType());
-            assertEquals(dtype, net.params().dataType());
+            assertEquals(dtype, net.getModelParams().dataType());
             boolean eq = outExp.equalsWithEps(outAct, 0.01);
             assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct);
         }
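The three 100bX regression classes repeat one invariant for every data type under test: configuration, parameters, flattened gradients, and updater state must all carry the network's dtype. Collected into a helper, the whole check is (a sketch using only the accessors shown in the hunks above):

    static void assertNetworkDtype(MultiLayerNetwork net, DataType dtype) {
        assertEquals(dtype, net.getNetConfiguration().getDataType());
        assertEquals(dtype, net.getModelParams().dataType());             // was net.params()
        assertEquals(dtype, net.getFlattenedGradients().dataType());
        assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
    }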
@@ -76,7 +76,7 @@ public class CustomLayer extends FeedForwardLayer {
         //For the most part, it's the same for each type of layer
 
         CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType);
-        myCustomLayer.setListeners(iterationListeners);             //Set the iteration listeners, if any
+        myCustomLayer.addTrainingListeners(iterationListeners);             //Set the iteration listeners, if any
         myCustomLayer.setIndex(layerIndex);                         //Integer index of the layer
 
         //Parameter view array: In Deeplearning4j, the network parameters for the entire network (all layers) are
@@ -20,7 +20,6 @@
 
 package org.deeplearning4j.regressiontest.customlayer100a;
 
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
@@ -56,7 +55,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter: the configuration class type
         INDArray firstHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
         INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
 
-        IActivation activation1 = layerConf().getActivationFn();
+        IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
         IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
 
         //IActivation function instances modify the activation functions in-place
@@ -75,7 +74,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> {
     @Override
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         /*
-        The baockprop gradient method here is very similar to the BaseLayer backprop gradient implementation
+        The baockprop gradient method here is very similar to the BaseLayerConfiguration backprop gradient implementation
         The only major difference is the two activation functions we have added in this example.
 
         Note that epsilon is dL/da - i.e., the derivative of the loss function with respect to the activations.
@@ -105,14 +104,14 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> {
         INDArray epsilonFirstHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
         INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
 
-        IActivation activation1 = layerConf().getActivationFn();
+        IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
         IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
 
         //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz
         activation1.backprop(firstHalf, epsilonFirstHalf);
         activation2.backprop(secondHalf, epsilonSecondHalf);
 
-        //The remaining code for this method: just copy & pasted from BaseLayer.backpropGradient
+        //The remaining code for this method: just copy & pasted from BaseLayerConfiguration.backpropGradient
 //        INDArray delta = epsilon.muli(activationDerivative);
         if (maskArray != null) {
             activationDerivative.muliColumnVector(maskArray);
@@ -128,7 +127,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> {
         ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
         ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad);
 
-        INDArray epsilonNext = paramsTable.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();
+        INDArray epsilonNext = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();
 
         return new Pair<>(ret, epsilonNext);
     }
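Both CustomLayerImpl hunks touch the same accessor: the typed view of the layer's configuration is now reached through getTypedLayerConfiguration() instead of the old layerConf() shorthand. For context, the idea the example layer implements is to split the pre-activation output column-wise and run a different activation function on each half. A sketch of that forward-pass fragment (output and training come from the surrounding activate(...) method):

    long columns = output.size(1);
    INDArray firstHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
    INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));

    IActivation activation1 = getTypedLayerConfiguration().getActivationFn();  // was layerConf()
    IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();

    // IActivation applies in place, so each half ends up with its own nonlinearity:
    activation1.getActivation(firstHalf, training);
    activation2.getActivation(secondHalf, training);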
@@ -190,7 +190,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
 
 
                 //Check score
-                double scoreDl4j = net.score();
+                double scoreDl4j = net.getScore();
                 double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
                 assertEquals(scoreDl4j, scoreSd, 1e-6, testName);
 
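CompareTrainingImplementations cross-checks DL4J's reported training score against the same loss computed by a SameDiff graph; only the accessor changed. Reduced to the comparison itself (map, lossMse, sd, and testName are defined earlier in the test):

    double scoreDl4j = net.getScore();                     // was net.score()
    double scoreSd = map.get(lossMse.name()).getDouble(0)  // data-fit (MSE) term from SameDiff
            + sd.calcRegularizationScore();                // plus the regularization term
    assertEquals(scoreDl4j, scoreSd, 1e-6, testName);      // must agree to within 1e-6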
@@ -104,7 +104,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
 
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        net.addListeners(new ScoreIterationListener(1));
+        net.addTrainingListeners(new ScoreIterationListener(1));
 
         //Test net that hasn't been trained yet
         Exception e = new Exception();
@@ -161,7 +161,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
         CrashReportingUtil.crashDumpOutputDirectory(dir);
 
         ComputationGraph cg = net.toComputationGraph();
-        cg.setListeners(new ScoreIterationListener(1));
+        cg.addTrainingListeners(new ScoreIterationListener(1));
 
         //Test net that hasn't been trained yet
         CrashReportingUtil.writeMemoryCrashDump(cg, e);
@@ -156,7 +156,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
 
         MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
 
     }
@@ -173,7 +173,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
             MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
             Assertions.assertNotNull(network);
             assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-            assertEquals(net.params(), network.params());
+            assertEquals(net.getModelParams(), network.getModelParams());
             assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
         }
     }
@@ -81,7 +81,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);
 
         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
 
@@ -125,7 +125,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis);
 
         assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-        assertEquals(net.params(), network.params());
+        assertEquals(net.getModelParams(), network.getModelParams());
         assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
 
@@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);
 
         assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-        assertEquals(cg.params(), network.params());
+        assertEquals(cg.getModelParams(), network.getModelParams());
         assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
 
@@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
         ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
 
         assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-        assertEquals(cg.params(), network.params());
+        assertEquals(cg.getModelParams(), network.getModelParams());
         assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
     }
 
@@ -346,7 +346,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
 
         //Also test reading  both model and normalizer from stream (correctly)
         Pair<MultiLayerNetwork,Normalizer> pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true);
-        assertEquals(net.params(), pair.getFirst().params());
+        assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
         assertNotNull(pair.getSecond());
     }
 
@@ -395,7 +395,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
 
         //Also test reading  both model and normalizer from stream (correctly)
         Pair<ComputationGraph,Normalizer> pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true);
-        assertEquals(net.params(), pair.getFirst().params());
+        assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
         assertNotNull(pair.getSecond());
     }
 
@@ -496,6 +496,6 @@ public class ModelSerializerTest extends BaseDL4JTest {
         assertTrue(entries.contains("otherData.bin"));
 
         ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile);
-        assertEquals(net.params(), restoredNet.params());
+        assertEquals(net.getModelParams(), restoredNet.getModelParams());
     }
 }
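Every ModelSerializerTest hunk asserts the same round-trip contract: a model written to a file or stream and restored must match the original in configuration JSON, flattened parameters, and updater state. As one helper (a sketch; it assumes writeModel's boolean flag still means "persist the updater state"):

    static void assertRoundTrips(MultiLayerNetwork net, File tempFile) throws IOException {
        ModelSerializer.writeModel(net, tempFile, true);  // true: include updater state
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tempFile);
        assertEquals(net.getNetConfiguration().toJson(), restored.getNetConfiguration().toJson());
        assertEquals(net.getModelParams(), restored.getModelParams());  // was params()
        assertEquals(net.getUpdater().getStateViewArray(), restored.getUpdater().getStateViewArray());
    }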
@@ -21,7 +21,6 @@
 package org.deeplearning4j.nn.modelimport.keras.layers;
 
 import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
@@ -80,10 +79,6 @@ public class TFOpLayer extends LayerConfiguration {
     public  void setNIn(InputType inputType, boolean override){}
 
 
-    @Override
-    public GradientNormalization getGradientNormalization(){return null;}
-
-
     @Override
     public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                                                                 Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
@@ -91,14 +86,11 @@ public class TFOpLayer extends LayerConfiguration {
 
         LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
         TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, lconf, networkDataType);
-        tfOpLayerImpl.setListeners(trainingListeners);
+        tfOpLayerImpl.addTrainingListeners(trainingListeners);
         tfOpLayerImpl.setIndex(layerIndex);
         return tfOpLayerImpl;
     }
 
-    @Override
-    public double getGradientNormalizationThreshold(){return 0.;}
-
     @Override
     public List<Regularization> getRegularizationByParam(String paramName){return null;}
 
@@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
-import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -448,8 +448,8 @@ public class KerasLSTM extends KerasLayer {
 
 
         FeedForwardLayer ffl;
-        if(this.layer instanceof BaseWrapperLayer){
-            BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+        if(this.layer instanceof BaseWrapperLayerConfiguration){
+            BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration)this.layer;
             ffl = (FeedForwardLayer)bwl.getUnderlying();
         } else {
             ffl = (FeedForwardLayer) this.layer;
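The KerasLSTM change is purely the rename of the wrapper base class to BaseWrapperLayerConfiguration; the unwrap logic is untouched. The pattern the hunk preserves, for reference:

    // Wrapper configurations such as LastTimeStep or MaskZeroLayer hold the real
    // layer inside; the import and casts follow the renamed class, nothing else changes.
    FeedForwardLayer ffl = (this.layer instanceof BaseWrapperLayerConfiguration)
            ? (FeedForwardLayer) ((BaseWrapperLayerConfiguration) this.layer).getUnderlying()
            : (FeedForwardLayer) this.layer;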
Some files were not shown because too many files have changed in this diff.