Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-14 07:31:32 +02:00
parent a5dfdcb18f
commit 0f21ed9ec5
317 changed files with 4528 additions and 4191 deletions

View File

@@ -47,6 +47,7 @@ import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -54,9 +55,11 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -181,6 +184,7 @@ public class App {
 .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
 .gradientNormalizationThreshold( 100 )
 //.weightInitFn( new WeightInitXavier() ) //this is internal
+.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5)))
 .weightInit( WeightInit.XAVIER)
 //.activationFn( new ActivationIdentity()) //this is internal
 .activation( Activation.IDENTITY )
@@ -232,10 +236,10 @@ public class App {
 copyParams(gen, dis, gan);
-//gen.setListeners(new PerformanceListener(10, true));
-//dis.setListeners(new PerformanceListener(10, true));
-//gan.setListeners(new PerformanceListener(10, true));
-gan.setListeners(new ScoreToChartListener("gan"));
+gen.addTrainingListeners(new PerformanceListener(10, true));
+dis.addTrainingListeners(new PerformanceListener(10, true));
+gan.addTrainingListeners(new PerformanceListener(10, true));
+gan.addTrainingListeners(new ScoreToChartListener("gan"));
 //dis.setListeners(new ScoreToChartListener("dis"));
 gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1));
@@ -322,23 +326,25 @@ public class App {
 int genLayerCount = gen.getLayers().length;
 for (int i = 0; i < gan.getLayers().length; i++) {
 if (i < genLayerCount) {
-gen.getLayer(i).setParams(gan.getLayer(i).params());
+if(gan.getLayer(i).getParams() != null)
+gen.getLayer(i).setParams(gan.getLayer(i).getParams());
 } else {
-dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+if(gan.getLayer(i).getParams() != null)
+dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
 }
 }
 }
 private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
 for (int i = 0; i < gen.getLayers().length; i++) {
-gen.getLayer(i).setParams(gan.getLayer(i).params());
+gen.getLayer(i).setParams(gan.getLayer(i).getParams());
 }
 }
 private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
 int genLayerCount = gen.getLayers().length;
 for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
+gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams());
 }
 }
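The copyParams helper above now copies GAN weights back into the generator and discriminator only when a layer actually exposes parameters (getParams() returns null for parameter-less layers such as activation or dropout layers). A minimal standalone sketch of the same guard, assuming the renamed getParams()/setParams() accessors used in this commit; the consolidated null check is an illustration, not the committed code:

    // Sketch: copy GAN parameters back into generator/discriminator,
    // skipping layers that have no parameters (getParams() == null).
    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            INDArray p = gan.getLayer(i).getParams();
            if (p == null) continue;            // e.g. ActivationLayer, DropoutLayer
            if (i < genLayerCount) {
                gen.getLayer(i).setParams(p);
            } else {
                dis.getLayer(i - genLayerCount).setParams(p);
            }
        }
    }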

View File

@@ -115,15 +115,15 @@ public class GAN {
 public void setGeneratorListeners(BaseTrainingListener[] listeners) {
-generator.setListeners(listeners);
+generator.addTrainingListeners(listeners);
 }
 public void setDiscriminatorListeners(BaseTrainingListener[] listeners) {
-discriminator.setListeners(listeners);
+discriminator.addTrainingListeners(listeners);
 }
 public void setGanListeners(BaseTrainingListener[] listeners) {
-gan.setListeners(listeners);
+gan.addTrainingListeners(listeners);
 }
 public void fit(DataSetIterator realData, int numEpochs) {
@@ -239,9 +239,9 @@ public class GAN {
 int genLayerCount = generator.getLayers().length;
 for (int i = 0; i < gan.getLayers().length; i++) {
 if (i < genLayerCount) {
-generator.getLayer(i).setParams(gan.getLayer(i).params());
+generator.getLayer(i).setParams(gan.getLayer(i).getParams());
 } else {
-discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
+discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams());
 }
 }
 }
@@ -252,7 +252,7 @@ public class GAN {
 */
 private void updateGeneratorFromGan() {
 for (int i = 0; i < generator.getLayers().length; i++) {
-generator.getLayer(i).setParams(gan.getLayer(i).params());
+generator.getLayer(i).setParams(gan.getLayer(i).getParams());
 }
 }
@@ -263,7 +263,7 @@ public class GAN {
 private void updateGanWithDiscriminator() {
 int genLayerCount = generator.getLayers().length;
 for (int i = genLayerCount; i < gan.getLayers().length; i++) {
-gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params());
+gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).getParams());
 }
 }
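Throughout this commit, listener registration moves from setListeners(...)/addListeners(...) to addTrainingListeners(...). A minimal usage sketch under that assumption, with listener types taken from the files above:

    // Sketch: attaching training listeners via the renamed API used in this commit.
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.addTrainingListeners(new PerformanceListener(10, true),
                             new ScoreToChartListener("gan"));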

View File

@@ -155,8 +155,8 @@ public class MnistDCGANExample {
 .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build())
 .build();
-gan.getGenerator().setListeners(new PerformanceListener(1, true));
-gan.getDiscriminator().setListeners(new PerformanceListener(1, true));
+gan.getGenerator().addTrainingListeners(new PerformanceListener(1, true));
+gan.getDiscriminator().addTrainingListeners(new PerformanceListener(1, true));
 Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

View File

@@ -205,7 +205,7 @@ public class TestServer2 {
 //PostgresStatsStorage psqlStore = new PostgresStatsStorage();
 int listenerFrequency = 2;
 //net.setListeners(new StatsListener(psqlStore, listenerFrequency), new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
-net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
+net.addTrainingListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200));
 //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized

View File

@@ -290,7 +290,7 @@ public class IntegrationTestBaselineGenerator {
 for (int i : layersToTrain) {
 mln.pretrainLayer(i, dsi);
 }
-paramsPostTraining = mln.params();
+paramsPostTraining = mln.getModelParams();
 } else if (modelType == ModelType.CG) {
 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
 Preconditions.checkState(layersToTrain != null, "ILayer names must not be null");
@@ -298,7 +298,7 @@
 for (String i : layersToTrain) {
 cg.pretrainLayer(i, iter);
 }
-paramsPostTraining = cg.params();
+paramsPostTraining = cg.getModelParams();
 } else {
 throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests");
 }
@@ -314,7 +314,7 @@
 CollectScoresListener l = new CollectScoresListener(1);
 if (modelType != ModelType.SAMEDIFF)
-m.setListeners(l);
+m.addTrainingListeners(l);
 History h = null;
 if (modelType == ModelType.MLN) {
@@ -349,7 +349,7 @@
 }
 } else {
 File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
-IntegrationTestRunner.write(m.params(), p);
+IntegrationTestRunner.write(m.getModelParams(), p);
 }
 }
 }

View File

@@ -191,7 +191,7 @@ public class IntegrationTestRunner {
 MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
 assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal");
-assertEquals( loaded.params(), mln.params(), "Params not equal");
+assertEquals( loaded.getModelParams(), mln.getModelParams(), "Params not equal");
 assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal");
 } else if(config instanceof ComputationGraphConfiguration ){
 ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config;
@@ -201,7 +201,7 @@
 ComputationGraph loaded = ComputationGraph.load(savedModel, true);
 assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" );
-assertEquals( loaded.params(), cg.params(), "Params not equal");
+assertEquals( loaded.getModelParams(), cg.getModelParams(), "Params not equal");
 assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal");
 } else if(config instanceof SameDiff){
 sd = (SameDiff)config;
@@ -389,7 +389,7 @@
 for( int i : layersToTrain){
 mln.pretrainLayer(i, dsi);
 }
-paramsPostTraining = mln.params();
+paramsPostTraining = mln.getModelParams();
 layers = mln.getLayers();
 } else if(modelType == ModelType.CG) {
 String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
@@ -398,7 +398,7 @@
 for( String i : layersToTrain){
 cg.pretrainLayer(i, iter);
 }
-paramsPostTraining = cg.params();
+paramsPostTraining = cg.getModelParams();
 layers = cg.getLayers();
 } else {
 throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models");
@@ -439,7 +439,7 @@
 CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength);
 CollectScoresListener l = new CollectScoresListener(1);
 if(modelType != ModelType.SAMEDIFF) {
-m.setListeners(l);
+m.addTrainingListeners(l);
 }
 int iterBefore;
@@ -519,10 +519,10 @@
 if(modelType != ModelType.SAMEDIFF) {
 File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME);
 INDArray paramsExp = read(p);
-INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
+INDArray z = exceedsRelError(m.getModelParams(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining());
 int count = z.sumNumber().intValue();
 if (count > 0) {
-logFailedParams(20, "Parameter", layers, z, paramsExp, m.params());
+logFailedParams(20, "Parameter", layers, z, paramsExp, m.getModelParams());
 }
 assertEquals( 0, count, "Number of params exceeded max relative error");
 } else {
@@ -607,12 +607,12 @@
 ModelSerializer.writeModel(m, f, true);
 MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
 assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration());
-assertEquals(mln.params(), restored.params());
+assertEquals(mln.getModelParams(), restored.getModelParams());
 } else if(modelType == ModelType.CG){
 ModelSerializer.writeModel(m, f, true);
 ComputationGraph restored = ComputationGraph.load(f, true);
 assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-assertEquals(cg.params(), restored.params());
+assertEquals(cg.getModelParams(), restored.getModelParams());
 } else {
 sd.save(f, true);
 SameDiff restored = SameDiff.load(f, true);

View File

@@ -49,7 +49,7 @@ public class TestUtils {
 restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
-assertEquals(net.params(), restored.params());
+assertEquals(net.getModelParams(), restored.getModelParams());
 } catch (IOException e){
 //Should never happen
 throw new RuntimeException(e);
@@ -74,7 +74,7 @@
 restored = ModelSerializer.restoreComputationGraph(bais, true);
 assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-assertEquals(net.params(), restored.params());
+assertEquals(net.getModelParams(), restored.getModelParams());
 } catch (IOException e){
 //Should never happen
 throw new RuntimeException(e);

View File

@@ -26,7 +26,7 @@ import org.nd4j.common.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
-/**
+/** The ActivationIdentity activation function, just returns the input as is.
 * f(x) = x
 */
 @EqualsAndHashCode(callSuper = false)

View File

@@ -195,7 +195,7 @@ public abstract class BaseWorkspaceMgr<T extends Enum<T>> implements WorkspaceMg
 }
 @Override
-public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
+public INDArray validateArrayLocation(T arrayType, INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
 validateConfig(arrayType);
 if(scopeOutOfWs.contains(arrayType)){

View File

@@ -19,6 +19,7 @@ dependencies {
 testImplementation projects.cavisNative.cavisNativeCommon
 testImplementation projects.cavisNd4j.cavisNd4jCommonTests
 testImplementation projects.cavisDnn.cavisDnnCommonTests
+testImplementation projects.cavisDnn.cavisDnnNn
 implementation "org.apache.commons:commons-lang3"

View File

@@ -116,7 +116,7 @@ public class LayerHelperValidationUtil {
 MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
 net2With.init();
-net2With.params().assign(netOrig.params());
+net2With.getModelParams().assign(netOrig.getModelParams());
 log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses());
 removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
@@ -124,7 +124,7 @@
 Preconditions.checkNotNull(t.getFeatures(), "Features are not set (null)");
 for (boolean train : new boolean[]{false, true}) {
-assertEquals(net1NoHelper.params(), net2With.params());
+assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams());
 String s = "Feed forward test - " + t.getTestName() + " - " + (train ? "Train: " : "Test: ");
 List<INDArray> ff1;
 try {
@@ -180,7 +180,7 @@
 double maxRE = relError.maxNumber().doubleValue();
 log.info(s + "Output, max relative error: " + maxRE);
-assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
+assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams()); //Check that forward pass does not modify params
 assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
 }
 }
@@ -255,24 +255,24 @@
 net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
 net2With.init();
-net2With.params().assign(netOrig.params());
+net2With.getModelParams().assign(netOrig.getModelParams());
 log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
 removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 CollectScoresListener listener = new CollectScoresListener(1);
-net2With.setListeners(listener);
+net2With.addTrainingListeners(listener);
 net2With.fit(t.getData());
 for( int i=0; i<2; i++ ) {
 net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
 net2With.init();
-net2With.params().assign(netOrig.params());
+net2With.getModelParams().assign(netOrig.getModelParams());
 log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
 removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses());
 CollectScoresListener listener2 = new CollectScoresListener(1);
-net2With.setListeners(listener2);
+net2With.addTrainingListeners(listener2);
 net2With.fit(t.getData());
 DoubleArrayList listOrig = listener.getListScore();

View File

@@ -25,7 +25,7 @@ import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
-import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
@@ -67,7 +67,7 @@ public class TestUtils {
 restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
 assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
-assertEquals(net.params(), restored.params());
+assertEquals(net.getModelParams(), restored.getModelParams());
 } catch (IOException e){
 //Should never happen
 throw new RuntimeException(e);
@@ -91,7 +91,7 @@
 restored = ModelSerializer.restoreComputationGraph(bais, true);
 assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
-assertEquals(net.params(), restored.params());
+assertEquals(net.getModelParams(), restored.getModelParams());
 } catch (IOException e){
 //Should never happen
 throw new RuntimeException(e);
@@ -205,8 +205,8 @@
 return null;
 }
-public static L2Regularization getL2Reg(BaseLayer baseLayer){
-return getL2Reg(baseLayer.getRegularization());
+public static L2Regularization getL2Reg(BaseLayerConfiguration baseLayerConfiguration){
+return getL2Reg(baseLayerConfiguration.getRegularization());
 }
 public static L2Regularization getL2Reg(List<Regularization> l){
@@ -218,7 +218,7 @@
 return null;
 }
-public static WeightDecay getWeightDecayReg(BaseLayer bl){
+public static WeightDecay getWeightDecayReg(BaseLayerConfiguration bl){
 return getWeightDecayReg(bl.getRegularization());
 }
@@ -231,7 +231,7 @@
 return null;
 }
-public static double getL1(BaseLayer layer) {
+public static double getL1(BaseLayerConfiguration layer) {
 List<Regularization> l = layer.getRegularization();
 return getL1(l);
 }
@@ -246,7 +246,7 @@
 return l1Reg.getL1().valueAt(0,0);
 }
-public static double getL2(BaseLayer layer) {
+public static double getL2(BaseLayerConfiguration layer) {
 List<Regularization> l = layer.getRegularization();
 return getL2(l);
 }
@@ -269,7 +269,7 @@
 return getL2(layer.getRegularization());
 }
-public static double getWeightDecay(BaseLayer layer) {
+public static double getWeightDecay(BaseLayerConfiguration layer) {
 return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
 }
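The configuration class BaseLayer is renamed to BaseLayerConfiguration, so the regularization helpers above now take the configuration type. A brief usage sketch, assuming the cast pattern that appears later in this commit (conf.getConf(0).getLayer()); the layer index is illustrative:

    // Sketch: reading regularization settings through the renamed configuration type.
    BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
    double l1 = TestUtils.getL1(l0);
    double l2 = TestUtils.getL2(l0);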

View File

@@ -32,7 +32,6 @@ import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -183,7 +182,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
 MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
 model.init();
-model.setListeners(new ScoreIterationListener(listenerFreq));
+model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
 model.fit(lfw.next());
@@ -247,7 +246,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
 //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
 CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
-model.setListeners(listener);
+model.addTrainingListeners(listener);
 model.fit(cifar);

View File

@@ -226,7 +226,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -255,7 +255,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
@@ -304,7 +304,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -343,7 +343,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
@@ -386,7 +386,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
@@ -430,7 +430,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 int nSamples = 100;
 //Generate the training data
 INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
@@ -473,7 +473,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
@@ -496,9 +496,9 @@ public class TestEarlyStopping extends BaseDL4JTest {
 assertEquals(net.getnLayers(), mln.getnLayers());
 assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo());
-BaseLayer bl = (BaseLayer) net.getLayerConfiguration();
-assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.getLayerConfiguration()).getActivationFn().toString());
-assertEquals(bl.getIUpdater(), ((BaseLayer) mln.getLayerConfiguration()).getIUpdater());
+BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration();
+assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString());
+assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater());
 }
 @Test
@@ -511,7 +511,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build())
 .build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@@ -792,7 +792,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 TestListener tl = new TestListener();
-net.setListeners(tl);
+net.addTrainingListeners(tl);
 DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
 EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();

View File

@@ -84,7 +84,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -128,7 +128,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -165,7 +165,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
@@ -207,7 +207,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
@@ -241,7 +241,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
 .setOutputs("0").build();
 ComputationGraph net = new ComputationGraph(conf);
-net.setListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
 DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
 EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@@ -538,7 +538,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
 ComputationGraph net = new ComputationGraph(conf);
 TestEarlyStopping.TestListener tl = new TestEarlyStopping.TestListener();
-net.setListeners(tl);
+net.addTrainingListeners(tl);
 DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
 EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();

View File

@@ -84,7 +84,7 @@ public class EvalTest extends BaseDL4JTest {
 // Instantiate model
 MultiLayerNetwork model = new MultiLayerNetwork(conf);
 model.init();
-model.addListeners(new ScoreIterationListener(1));
+model.addTrainingListeners(new ScoreIterationListener(1));
 // Train-test split
 DataSetIterator iter = new IrisDataSetIterator(150, 150);
@@ -324,7 +324,7 @@ public class EvalTest extends BaseDL4JTest {
 MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
 net2.init();
-net2.setParams(net1.params());
+net2.setParams(net1.getModelParams());
 for(boolean useMask : new boolean[]{false, true}) {
@@ -405,7 +405,7 @@ public class EvalTest extends BaseDL4JTest {
 ComputationGraph net2 = new ComputationGraph(conf2);
 net2.init();
-net2.setParams(net1.params());
+net2.setParams(net1.getModelParams());
 for (boolean useMask : new boolean[]{false, true}) {
@@ -492,7 +492,7 @@ public class EvalTest extends BaseDL4JTest {
 DataSetIterator iter = new IrisDataSetIterator(30, 150);
 DataSetIterator iterTest = new IrisDataSetIterator(30, 150);
-net.setListeners(new EvaluativeListener(iterTest, 3));
+net.addTrainingListeners(new EvaluativeListener(iterTest, 3));
 for( int i=0; i<3; i++ ){
 net.fit(iter);

View File

@@ -26,7 +26,6 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -219,11 +218,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int k = 0; k < 20; k++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = name
 + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -323,11 +322,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int k = 0; k < 10; k++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = name
 + " - score did not (sufficiently) decrease during learning - activationFn="
@@ -554,11 +553,11 @@ public class BNGradientCheckTest extends BaseDL4JTest {
 net.setInput(0, ds.getFeatures());
 net.setLabels(ds.getLabels());
 net.computeGradientAndScore();
-double scoreBefore = net.score();
+double scoreBefore = net.getScore();
 for (int k = 0; k < 20; k++)
 net.fit(ds);
 net.computeGradientAndScore();
-double scoreAfter = net.score();
+double scoreAfter = net.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = name
 + " - score did not (sufficiently) decrease during learning - activationFn="

View File

@@ -27,7 +27,6 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
@@ -120,11 +119,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = name + " - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -212,11 +211,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = testName
 + "- score did not (sufficiently) decrease during learning - activationFn="

View File

@@ -105,11 +105,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testMinibatchApplication() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -184,11 +184,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -278,11 +278,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -452,11 +452,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
@@ -523,13 +523,13 @@ public class GradientCheckTests extends BaseDL4JTest {
 netGraph.setInputs(features);
 netGraph.setLabels(labels);
 netGraph.computeGradientAndScore();
-double scoreBefore = netGraph.score();
+double scoreBefore = netGraph.getScore();
 String msg;
 for (int epoch = 0; epoch < 5; epoch++)
 netGraph.fit(new INDArray[]{features}, new INDArray[]{labels});
 netGraph.computeGradientAndScore();
-double scoreAfter = netGraph.score();
+double scoreAfter = netGraph.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 msg = "elementWiseMultiplicationLayerTest() - score did not (sufficiently) decrease during learning - activationFn="
 + "Id" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id"
@@ -757,11 +757,11 @@ public class GradientCheckTests extends BaseDL4JTest {
 mln.setInput(ds.getFeatures());
 mln.setLabels(ds.getLabels());
 mln.computeGradientAndScore();
-double scoreBefore = mln.score();
+double scoreBefore = mln.getScore();
 for (int j = 0; j < 10; j++)
 mln.fit(ds);
 mln.computeGradientAndScore();
-double scoreAfter = mln.score();
+double scoreAfter = mln.getScore();
 //Can't test in 'characteristic mode of operation' if not learning
 String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn="
 + afn + ", lossFn=" + lf + ", layerNorm=" + layerNorm + ", outputActivation=" + outputActivation

View File

@@ -666,7 +666,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
 net.init();
 //Check params to avoid test flakiness on small or large params
-INDArray params = net.params();
+INDArray params = net.getModelParams();
 for( int x=0; x<params.length(); x++ ){
 while(Math.abs(params.getDouble(x)) < 0.01 || Math.abs(params.getDouble(x)) > 1.5){
 double d = Nd4j.getRandom().nextDouble();

View File

@ -37,10 +37,9 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
@ -254,8 +253,8 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
model2.init(); model2.init();
float[] p1 = model1.params().data().asFloat(); float[] p1 = model1.getModelParams().data().asFloat();
float[] p2 = model2.params().data().asFloat(); float[] p2 = model2.getModelParams().data().asFloat();
System.out.println(Arrays.toString(p1)); System.out.println(Arrays.toString(p1));
System.out.println(Arrays.toString(p2)); System.out.println(Arrays.toString(p2));
@ -266,20 +265,20 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
public void testTrainingListener() { public void testTrainingListener() {
MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
model1.init(); model1.init();
model1.addListeners(new ScoreIterationListener(1)); model1.addTrainingListeners(new ScoreIterationListener(1));
MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
model2.addListeners(new ScoreIterationListener(1)); model2.addTrainingListeners(new ScoreIterationListener(1));
model2.init(); model2.init();
Layer[] l1 = model1.getLayers(); Layer[] l1 = model1.getLayers();
for (int i = 0; i < l1.length; i++) { for (int i = 0; i < l1.length; i++) {
assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); assertTrue(l1[i].getTrainingListeners() != null && l1[i].getTrainingListeners().size() == 1);
} }
Layer[] l2 = model2.getLayers(); Layer[] l2 = model2.getLayers();
for (int i = 0; i < l2.length; i++) { for (int i = 0; i < l2.length; i++) {
assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); assertTrue(l2[i].getTrainingListeners() != null && l2[i].getTrainingListeners().size() == 1);
} }
} }
@ -384,10 +383,10 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
.inputType(InputType.convolutional(28, 28, 1)).build(); .inputType(InputType.convolutional(28, 28, 1)).build();
org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); BaseLayerConfiguration l2 = (BaseLayerConfiguration) conf.getConf(2).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); BaseLayerConfiguration l3 = (BaseLayerConfiguration) conf.getConf(3).getLayer();
assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);
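Listener wiring changes shape here as well: addListeners(...) becomes addTrainingListeners(...) and the per-layer query becomes getTrainingListeners(). A short sketch of the new wiring, assuming getConf() builds the same small configuration used by this test:

// Sketch only: attach a listener with the new API and check each layer sees exactly one.
MultiLayerNetwork model = new MultiLayerNetwork(getConf());    // getConf() as defined in this test class
model.init();
model.addTrainingListeners(new ScoreIterationListener(1));     // was model.addListeners(...)

for (Layer layer : model.getLayers()) {
    assertNotNull(layer.getTrainingListeners());                // was layer.getListeners()
    assertEquals(1, layer.getTrainingListeners().size());
}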

View File

@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
@ -100,7 +100,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
@Test @Test
public void testClone() { public void testClone() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true); NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0); BaseLayerConfiguration bl = (BaseLayerConfiguration) conf.getFlattenedLayerConfigurations().get(0);
conf.setStepFunction(new DefaultStepFunction()); conf.setStepFunction(new DefaultStepFunction());
NeuralNetConfiguration conf2 = conf.clone(); NeuralNetConfiguration conf2 = conf.clone();

View File

@ -158,7 +158,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
cg.setInput(0, input); cg.setInput(0, input);
cg.setLabel(0, target); cg.setLabel(0, target);
cg.computeGradientAndScore(); cg.computeGradientAndScore();
double score_dl4j = cg.score(); double score_dl4j = cg.getScore();
Map<String, INDArray> weights = cg.getParamTable(); Map<String, INDArray> weights = cg.getParamTable();
Gradient g = cg.gradient(); Gradient g = cg.gradient();
Map<String, INDArray> gradients = g.gradientForVariable(); Map<String, INDArray> gradients = g.gradientForVariable();

View File

@ -72,8 +72,8 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());
//With //With
conf = NeuralNetConfiguration.builder().activation(Activation.RELU) conf = NeuralNetConfiguration.builder().activation(Activation.RELU)
@ -83,8 +83,8 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString());
assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); assertEquals("tanh", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString());
} }
@ -99,11 +99,11 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
//With: //With:
final Distribution overriddenDistribution = new UniformDistribution(0, 1); final Distribution overriddenDistribution = new UniformDistribution(0, 1);
@ -117,11 +117,11 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
} }
/* /*
@ -137,8 +137,8 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0); assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
assertEquals(0.3, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
//With: //With:
conf = NeuralNetConfiguration.builder().learningRate(0.3) conf = NeuralNetConfiguration.builder().learningRate(0.3)
@ -148,8 +148,8 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0); assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0);
assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0);
//L1 and L2 without layerwise override: //L1 and L2 without layerwise override:
conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)
@ -158,10 +158,10 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.1, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0); assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0); assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0); assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);
//L1 and L2 with layerwise override: //L1 and L2 with layerwise override:
conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2)
@ -170,10 +170,10 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.9, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0); assertEquals(0.9, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0);
assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0); assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0);
assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0); assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0);
assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); assertEquals(0.8, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0);
}*/ }*/
@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
Map<Integer, Double> testMomentumAfter2 = new HashMap<>(); Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
testMomentumAfter2.put(0, 0.2); testMomentumAfter2.put(0, 0.2);
@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
} }
@Test @Test
@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta); assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)) conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp); assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0); assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
} }
@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0); assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0); assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0); assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
assertEquals(0.7, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0); assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
} }
@Test @Test
@ -287,13 +287,11 @@ public class LayerConfigTest extends BaseDL4JTest {
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
BaseLayerConfiguration bconf = (BaseLayerConfiguration) conf.getConf(0).getLayer();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, conf.getConf(0).getLayer().getGradientNormalization()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, conf.getConf(1).getLayer().getGradientNormalization()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0); assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0); assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
//With: //With:
conf = NeuralNetConfiguration.builder() conf = NeuralNetConfiguration.builder()
@ -308,11 +306,10 @@ public class LayerConfigTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, conf.getConf(0).getLayer().getGradientNormalization()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
assertEquals(GradientNormalization.None, conf.getConf(1).getLayer().getGradientNormalization()); assertEquals(GradientNormalization.None, bconf.getGradientNormalization());
assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0); assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
assertEquals(2.5, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0); assertEquals(2.5, bconf.getGradientNormalizationThreshold(), 0.0);
} }
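The long run of edits above is the BaseLayer to BaseLayerConfiguration rename: per-layer settings such as the updater are now read back through a BaseLayerConfiguration cast. A compact sketch of that read-back pattern, reusing the builder calls that appear in these hunks (the expected values assume the usual default-plus-override semantics):

// Sketch only: builder-level updater plus a per-layer override, read back through the
// renamed BaseLayerConfiguration (was BaseLayer).
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON))
        .layer(0, new DenseLayer.Builder().nIn(2).nOut(2)
                .updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
        .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
        .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer();
BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
assertEquals(1.0, ((RmsProp) l0.getIUpdater()).getRmsDecay(), 0.0);   // layer 0: explicit override
assertEquals(2.0, ((RmsProp) l1.getIUpdater()).getRmsDecay(), 0.0);   // layer 1: inherits the builder default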

View File

@ -162,12 +162,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
BaseLayer layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3); assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3); assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);
// Adam Updater // Adam Updater
@ -178,11 +178,11 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3); assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3);
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
@ -196,12 +196,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
net = new MultiLayerNetwork(conf); net = new MultiLayerNetwork(conf);
net.init(); net.init();
layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3); assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3); assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);

View File

@ -29,7 +29,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
@ -75,9 +75,9 @@ public class TestWeightNoise extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(wn, ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getWeightNoise()); assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).getLayerConfiguration()).getWeightNoise()); assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
assertEquals(wn, ((BaseLayer) net.getLayer(2).getLayerConfiguration()).getWeightNoise()); assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(2).getLayerConfiguration()).getWeightNoise());
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -95,9 +95,9 @@ public class TestWeightNoise extends BaseDL4JTest {
ComputationGraph graph = new ComputationGraph(conf2); ComputationGraph graph = new ComputationGraph(conf2);
graph.init(); graph.init();
assertEquals(wn, ((BaseLayer) graph.getLayer(0).getLayerConfiguration()).getWeightNoise()); assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(0).getLayerConfiguration()).getWeightNoise());
assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).getLayerConfiguration()).getWeightNoise()); assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) graph.getLayer(1).getLayerConfiguration()).getWeightNoise());
assertEquals(wn, ((BaseLayer) graph.getLayer(2).getLayerConfiguration()).getWeightNoise()); assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(2).getLayerConfiguration()).getWeightNoise());
TestUtils.testModelSerialization(graph); TestUtils.testModelSerialization(graph);
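Same rename in TestWeightNoise: the configured weight noise is read back through BaseLayerConfiguration. A rough sketch of the configuration being asserted, assuming the global and per-layer weightNoise(...) builder methods used elsewhere in this commit:

// Sketch only: one global weight-noise default, one per-layer DropConnect override,
// both asserted through the renamed BaseLayerConfiguration cast.
WeightNoise wn = new WeightNoise(new NormalDistribution(0, 1.0));
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .weightNoise(wn)
        .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
        .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).weightNoise(new DropConnect(0.25)).build())
        .layer(2, new DenseLayer.Builder().nIn(10).nOut(10).build())
        .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration()).getWeightNoise());
assertEquals(new DropConnect(0.25),
        ((BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration()).getWeightNoise());
assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(2).getLayerConfiguration()).getWeightNoise());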

View File

@ -124,7 +124,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.util.MaskLayer; import org.deeplearning4j.nn.conf.layers.util.MaskLayer;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
@ -260,8 +260,8 @@ public class DTypeTests extends BaseDL4JTest {
for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) { for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) {
LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0); LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0);
seenLayers.add(l.getClass()); seenLayers.add(l.getClass());
if (l instanceof BaseWrapperLayer) { if (l instanceof BaseWrapperLayerConfiguration) {
BaseWrapperLayer bwl = (BaseWrapperLayer) l; BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration) l;
seenLayers.add(bwl.getUnderlying().getClass()); seenLayers.add(bwl.getUnderlying().getClass());
} else if (l instanceof Bidirectional) { } else if (l instanceof Bidirectional) {
seenLayers.add(((Bidirectional) l).getFwd().getClass()); seenLayers.add(((Bidirectional) l).getFwd().getClass());
@ -321,17 +321,17 @@ public class DTypeTests extends BaseDL4JTest {
net.setInput(inD); net.setInput(inD);
net.setLabels(lD); net.setLabels(lD);
net.computeGradientAndScore(); net.computeGradientAndScore();
double scoreDouble = net.score(); double scoreDouble = net.getScore();
INDArray grads = net.getFlattenedGradients(); INDArray grads = net.getFlattenedGradients();
INDArray u = net.getUpdater().getStateViewArray(); INDArray u = net.getUpdater().getStateViewArray();
assertEquals(DataType.DOUBLE, net.params().dataType()); assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
assertEquals(DataType.DOUBLE, grads.dataType()); assertEquals(DataType.DOUBLE, grads.dataType());
assertEquals(DataType.DOUBLE, u.dataType()); assertEquals(DataType.DOUBLE, u.dataType());
MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT); MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT);
netFloat.initGradientsView(); netFloat.initGradientsView();
assertEquals(DataType.FLOAT, netFloat.params().dataType()); assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType()); assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
INDArray inF = inD.castTo(DataType.FLOAT); INDArray inF = inD.castTo(DataType.FLOAT);
@ -340,7 +340,7 @@ public class DTypeTests extends BaseDL4JTest {
netFloat.setInput(inF); netFloat.setInput(inF);
netFloat.setLabels(lF); netFloat.setLabels(lF);
netFloat.computeGradientAndScore(); netFloat.computeGradientAndScore();
double scoreFloat = netFloat.score(); double scoreFloat = netFloat.getScore();
INDArray gradsFloat = netFloat.getFlattenedGradients(); INDArray gradsFloat = netFloat.getFlattenedGradients();
INDArray uFloat = netFloat.getUpdater().getStateViewArray(); INDArray uFloat = netFloat.getUpdater().getStateViewArray();
@ -352,7 +352,7 @@ public class DTypeTests extends BaseDL4JTest {
MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF); MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF);
netFP16.initGradientsView(); netFP16.initGradientsView();
assertEquals(DataType.HALF, netFP16.params().dataType()); assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());
@ -362,7 +362,7 @@ public class DTypeTests extends BaseDL4JTest {
netFP16.setInput(inH); netFP16.setInput(inH);
netFP16.setLabels(lH); netFP16.setLabels(lH);
netFP16.computeGradientAndScore(); netFP16.computeGradientAndScore();
double scoreHalf = netFP16.score(); double scoreHalf = netFP16.getScore();
INDArray gradsHalf = netFP16.getFlattenedGradients(); INDArray gradsHalf = netFP16.getFlattenedGradients();
INDArray uHalf = netFP16.getUpdater().getStateViewArray(); INDArray uHalf = netFP16.getUpdater().getStateViewArray();
@ -406,17 +406,17 @@ public class DTypeTests extends BaseDL4JTest {
net.setInput(0, inD); net.setInput(0, inD);
net.setLabels(lD); net.setLabels(lD);
net.computeGradientAndScore(); net.computeGradientAndScore();
double scoreDouble = net.score(); double scoreDouble = net.getScore();
INDArray grads = net.getFlattenedGradients(); INDArray grads = net.getFlattenedGradients();
INDArray u = net.getUpdater().getStateViewArray(); INDArray u = net.getUpdater().getStateViewArray();
assertEquals(DataType.DOUBLE, net.params().dataType()); assertEquals(DataType.DOUBLE, net.getModelParams().dataType());
assertEquals(DataType.DOUBLE, grads.dataType()); assertEquals(DataType.DOUBLE, grads.dataType());
assertEquals(DataType.DOUBLE, u.dataType()); assertEquals(DataType.DOUBLE, u.dataType());
ComputationGraph netFloat = net.convertDataType(DataType.FLOAT); ComputationGraph netFloat = net.convertDataType(DataType.FLOAT);
netFloat.initGradientsView(); netFloat.initGradientsView();
assertEquals(DataType.FLOAT, netFloat.params().dataType()); assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType()); assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());
INDArray inF = inD.castTo(DataType.FLOAT); INDArray inF = inD.castTo(DataType.FLOAT);
@ -425,7 +425,7 @@ public class DTypeTests extends BaseDL4JTest {
netFloat.setInput(0, inF); netFloat.setInput(0, inF);
netFloat.setLabels(lF); netFloat.setLabels(lF);
netFloat.computeGradientAndScore(); netFloat.computeGradientAndScore();
double scoreFloat = netFloat.score(); double scoreFloat = netFloat.getScore();
INDArray gradsFloat = netFloat.getFlattenedGradients(); INDArray gradsFloat = netFloat.getFlattenedGradients();
INDArray uFloat = netFloat.getUpdater().getStateViewArray(); INDArray uFloat = netFloat.getUpdater().getStateViewArray();
@ -437,7 +437,7 @@ public class DTypeTests extends BaseDL4JTest {
ComputationGraph netFP16 = net.convertDataType(DataType.HALF); ComputationGraph netFP16 = net.convertDataType(DataType.HALF);
netFP16.initGradientsView(); netFP16.initGradientsView();
assertEquals(DataType.HALF, netFP16.params().dataType()); assertEquals(DataType.HALF, netFP16.getModelParams().dataType());
assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType());
assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType());
@ -447,7 +447,7 @@ public class DTypeTests extends BaseDL4JTest {
netFP16.setInput(0, inH); netFP16.setInput(0, inH);
netFP16.setLabels(lH); netFP16.setLabels(lH);
netFP16.computeGradientAndScore(); netFP16.computeGradientAndScore();
double scoreHalf = netFP16.score(); double scoreHalf = netFP16.getScore();
INDArray gradsHalf = netFP16.getFlattenedGradients(); INDArray gradsHalf = netFP16.getFlattenedGradients();
INDArray uHalf = netFP16.getUpdater().getStateViewArray(); INDArray uHalf = netFP16.getUpdater().getStateViewArray();
@ -536,7 +536,7 @@ public class DTypeTests extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); net.initGradientsView();
assertEquals(networkDtype, net.params().dataType(), msg); assertEquals(networkDtype, net.getModelParams().dataType(), msg);
assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
@ -641,7 +641,7 @@ public class DTypeTests extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); net.initGradientsView();
assertEquals(networkDtype, net.params().dataType(), msg); assertEquals(networkDtype, net.getModelParams().dataType(), msg);
assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
@ -754,7 +754,7 @@ public class DTypeTests extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); net.initGradientsView();
assertEquals(networkDtype, net.params().dataType(), msg); assertEquals(networkDtype, net.getModelParams().dataType(), msg);
assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
@ -827,7 +827,7 @@ public class DTypeTests extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); net.initGradientsView();
assertEquals(networkDtype, net.params().dataType(), msg); assertEquals(networkDtype, net.getModelParams().dataType(), msg);
assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
@ -916,7 +916,7 @@ public class DTypeTests extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); net.initGradientsView();
assertEquals(networkDtype, net.params().dataType(), msg); assertEquals(networkDtype, net.getModelParams().dataType(), msg);
assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg);
assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg);
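In DTypeTests the assertions now go through getModelParams(), but the round-trip they exercise is unchanged: build in one dtype, convert, and confirm that parameters, gradients and updater state follow. Roughly, with buildNet(DataType) as a hypothetical helper standing in for the fixtures above:

// Sketch only: convert a double-precision net and check that the parameter dtype follows.
MultiLayerNetwork netDouble = buildNet(DataType.DOUBLE);
netDouble.initGradientsView();
assertEquals(DataType.DOUBLE, netDouble.getModelParams().dataType());   // was params().dataType()

MultiLayerNetwork netFloat = netDouble.convertDataType(DataType.FLOAT);
netFloat.initGradientsView();
assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType());
assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType());
assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType());

MultiLayerNetwork netHalf = netDouble.convertDataType(DataType.HALF);
netHalf.initGradientsView();
assertEquals(DataType.HALF, netHalf.getModelParams().dataType());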

View File

@ -520,9 +520,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);
INDArray initialParams = graph.params().dup(); INDArray initialParams = graph.getModelParams().dup();
graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong}); graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong});
INDArray afterParams = graph.params(); INDArray afterParams = graph.getModelParams();
assertNotEquals(initialParams, afterParams); assertNotEquals(initialParams, afterParams);
} }

View File

@ -117,7 +117,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
boolean orderOK = Arrays.equals(expOrder1, order) || Arrays.equals(expOrder2, order); boolean orderOK = Arrays.equals(expOrder1, order) || Arrays.equals(expOrder2, order);
assertTrue(orderOK); assertTrue(orderOK);
INDArray params = graph.params(); INDArray params = graph.getModelParams();
assertNotNull(params); assertNotNull(params);
// confirm param shape is what is expected // confirm param shape is what is expected
@ -129,7 +129,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
// params are set // params are set
graph.setParams(arr); graph.setParams(arr);
params = graph.params(); params = graph.getModelParams();
assertEquals(arr, params); assertEquals(arr, params);
//Number of inputs and outputs: //Number of inputs and outputs:

View File

@ -108,7 +108,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
} }
} }
int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0); int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
assertEquals(0, count); assertEquals(0, count);
@ -125,7 +125,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
} }
} }
count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0); count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0);
assertEquals(0, count); assertEquals(0, count);
} }
} }
@ -176,7 +176,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
cg.pretrainLayer("0", ds); cg.pretrainLayer("0", ds);
assertEquals(net.params(), cg.params()); assertEquals(net.getModelParams(), cg.getModelParams());
} }
} }
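The pretraining checks above guard against NaNs in the renamed parameter view. The guard itself is a one-liner over getModelParams(), taken almost verbatim from the hunk:

// Sketch only: count NaN entries in the flattened parameter vector; the test expects zero.
INDArray allParams = cg.getModelParams();                  // was cg.params()
int nanCount = Nd4j.getExecutioner()
        .exec(new MatchCondition(allParams, Conditions.isNan()))
        .getInt(0);
assertEquals(0, nanCount);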

View File

@ -159,7 +159,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
DataSet ds = iris.next(); DataSet ds = iris.next();
graph.setInput(0, ds.getFeatures()); graph.setInput(0, ds.getFeatures());
net.setParams(graph.params()); net.setParams(graph.getModelParams());
Map<String, INDArray> activations = graph.feedForward(false); Map<String, INDArray> activations = graph.feedForward(false);
List<INDArray> feedForward = net.feedForward(ds.getFeatures()); List<INDArray> feedForward = net.feedForward(ds.getFeatures());
@ -184,7 +184,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
int[] expOrder = new int[]{0, 1, 2}; int[] expOrder = new int[]{0, 1, 2};
assertArrayEquals(expOrder, order); //Only one valid order: 0 (input) -> 1 (firstlayer) -> 2 (outputlayer) assertArrayEquals(expOrder, order); //Only one valid order: 0 (input) -> 1 (firstlayer) -> 2 (outputlayer)
INDArray params = graph.params(); INDArray params = graph.getModelParams();
assertNotNull(params); assertNotNull(params);
int nParams = getNumParams(); int nParams = getNumParams();
@ -194,7 +194,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
assertEquals(nParams, arr.length()); assertEquals(nParams, arr.length());
graph.setParams(arr); graph.setParams(arr);
params = graph.params(); params = graph.getModelParams();
assertEquals(arr, params); assertEquals(arr, params);
//Number of inputs and outputs: //Number of inputs and outputs:
@ -315,8 +315,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
graph.fit(iris); graph.fit(iris);
//Check that parameters are equal for both models after fitting: //Check that parameters are equal for both models after fitting:
INDArray paramsMLN = net.params(); INDArray paramsMLN = net.getModelParams();
INDArray paramsGraph = graph.params(); INDArray paramsGraph = graph.getModelParams();
assertNotEquals(params, paramsGraph); assertNotEquals(params, paramsGraph);
assertEquals(paramsMLN, paramsGraph); assertEquals(paramsMLN, paramsGraph);
@ -636,7 +636,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.init(); net.init();
net.setListeners(new ScoreIterationListener(1)); net.addTrainingListeners(new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(10, 150); DataSetIterator iter = new IrisDataSetIterator(10, 150);
net.pretrain(iter); net.pretrain(iter);
@ -675,7 +675,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
ComputationGraph netNoReg = new ComputationGraph(confNoReg); ComputationGraph netNoReg = new ComputationGraph(confNoReg);
netNoReg.init(); netNoReg.init();
netNoReg.setParams(net.params().dup()); netNoReg.setParams(net.getModelParams().dup());
//Score single example, and compare to scoreExamples: //Score single example, and compare to scoreExamples:
INDArray input = Nd4j.rand(3, nIn); INDArray input = Nd4j.rand(3, nIn);
@ -878,13 +878,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
net.setParam("first_b", Nd4j.ones(1, 5)); net.setParam("first_b", Nd4j.ones(1, 5));
net.setParam("output_W", Nd4j.ones(5, 3)); net.setParam("output_W", Nd4j.ones(5, 3));
net.setParam("output_b", Nd4j.ones(1, 3)); net.setParam("output_b", Nd4j.ones(1, 3));
INDArray actualParams = net.params(); INDArray actualParams = net.getModelParams();
// Confirm params // Confirm params
assertEquals(Nd4j.ones(1, 43), actualParams); assertEquals(Nd4j.ones(1, 43), actualParams);
net.update(expectedGradient); net.update(expectedGradient);
actualParams = net.params(); actualParams = net.getModelParams();
assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
} }
@ -1638,7 +1638,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
conf3.setTopologicalOrderStr(null); conf3.setTopologicalOrderStr(null);
ComputationGraph cg3 = new ComputationGraph(conf3); ComputationGraph cg3 = new ComputationGraph(conf3);
cg3.init(); cg3.init();
cg3.setParams(cg2.params()); cg3.setParams(cg2.getModelParams());
int[] order3 = cg3.topologicalSortOrder(); int[] order3 = cg3.topologicalSortOrder();
List<String> strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr(); List<String> strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr();
@ -1712,7 +1712,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
exp.add(ComputationGraph.class); exp.add(ComputationGraph.class);
MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener(); MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener();
net.setListeners(listener); net.addTrainingListeners(listener);
INDArray f = Nd4j.create(1,10); INDArray f = Nd4j.create(1,10);
INDArray l = Nd4j.create(1,10); INDArray l = Nd4j.create(1,10);
@ -1874,7 +1874,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
ComputationGraph cg = new ComputationGraph(conf); ComputationGraph cg = new ComputationGraph(conf);
cg.init(); cg.init();
cg.params().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11)); cg.getModelParams().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11));
INDArray p0w = cg.getParam("layer_zero_W"); INDArray p0w = cg.getParam("layer_zero_W");
assertEquals(Nd4j.linspace(1, 100, 100).reshape('f', 10, 10), p0w); assertEquals(Nd4j.linspace(1, 100, 100).reshape('f', 10, 10), p0w);

View File

@ -56,7 +56,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.init(); net.init();
INDArray params = net.params(); INDArray params = net.getModelParams();
ComputationGraph net2 = new ComputationGraph(conf); ComputationGraph net2 = new ComputationGraph(conf);
@ -65,11 +65,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
ComputationGraph net3 = new ComputationGraph(conf); ComputationGraph net3 = new ComputationGraph(conf);
net3.init(params, false); net3.init(params, false);
assertEquals(params, net2.params()); assertEquals(params, net2.getModelParams());
assertEquals(params, net3.params()); assertEquals(params, net3.getModelParams());
assertNotSame(params, net2.params()); //Different objects due to clone assertNotSame(params, net2.getModelParams()); //Different objects due to clone
assertSame(params, net3.params()); //Same object due to clone assertSame(params, net3.getModelParams()); //Same object due to clone
Map<String, INDArray> paramsMap = net.getParamTable(); Map<String, INDArray> paramsMap = net.getParamTable();
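TestSetGetParameters keeps its copy-versus-shared distinction, now expressed with getModelParams(): init(params, false) is expected to reuse the passed array, while initializing with a clone yields an equal but distinct array. A sketch of that contract (conf is whatever configuration the test builds; the clone flag on net2 is assumed, since that line is not shown above):

// Sketch only: conf is the configuration built by this test; the clone flag on net2 is assumed.
ComputationGraph net1 = new ComputationGraph(conf);
net1.init();
INDArray params = net1.getModelParams();                   // was net1.params()

ComputationGraph net2 = new ComputationGraph(conf);
net2.init(params, true);                                    // clone: equal values, different array

ComputationGraph net3 = new ComputationGraph(conf);
net3.init(params, false);                                   // no clone: the array is used as-is

assertEquals(params, net2.getModelParams());
assertEquals(params, net3.getModelParams());
assertNotSame(params, net2.getModelParams());               // copy
assertSame(params, net3.getModelParams());                  // shared reference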

View File

@ -103,14 +103,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
net.setInput(0, in1); net.setInput(0, in1);
net.setLabel(0, labels1); net.setLabel(0, labels1);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score1 = net.score(); double score1 = net.getScore();
Gradient g1 = net.gradient(); Gradient g1 = net.gradient();
net.setInput(0, in2); net.setInput(0, in2);
net.setLabel(0, labels2); net.setLabel(0, labels2);
net.setLayerMaskArrays(null, new INDArray[] {labelMask}); net.setLayerMaskArrays(null, new INDArray[] {labelMask});
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2 = net.score(); double score2 = net.getScore();
Gradient g2 = net.gradient(); Gradient g2 = net.gradient();
//Scores and gradients should be identical for two cases (given mask array) //Scores and gradients should be identical for two cases (given mask array)
@ -134,7 +134,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
} }
net.setLabel(0, labels2); net.setLabel(0, labels2);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2a = net.score(); double score2a = net.getScore();
Gradient g2a = net.gradient(); Gradient g2a = net.gradient();
assertEquals(score2, score2a, 1e-6); assertEquals(score2, score2a, 1e-6);
for (String s : g2map.keySet()) { for (String s : g2map.keySet()) {
@ -200,7 +200,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
net.setInput(0, in1); net.setInput(0, in1);
net.setLabel(0, labels1); net.setLabel(0, labels1);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score1 = net.score(); double score1 = net.getScore();
Gradient g1 = net.gradient(); Gradient g1 = net.gradient();
Map<String, INDArray> map = g1.gradientForVariable(); Map<String, INDArray> map = g1.gradientForVariable();
for (String s : map.keySet()) { for (String s : map.keySet()) {
@ -211,7 +211,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
net.setLabel(0, labels2); net.setLabel(0, labels2);
net.setLayerMaskArrays(new INDArray[] {inputMask}, null); net.setLayerMaskArrays(new INDArray[] {inputMask}, null);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2 = net.score(); double score2 = net.getScore();
Gradient g2 = net.gradient(); Gradient g2 = net.gradient();
Map<String, INDArray> activations2 = net.feedForward(); Map<String, INDArray> activations2 = net.feedForward();
@ -236,7 +236,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
net.setInput(0, in2); net.setInput(0, in2);
net.setLayerMaskArrays(new INDArray[]{inputMask}, null); net.setLayerMaskArrays(new INDArray[]{inputMask}, null);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2a = net.score(); double score2a = net.getScore();
Gradient g2a = net.gradient(); Gradient g2a = net.gradient();
assertEquals(score2, score2a, 1e-12); assertEquals(score2, score2a, 1e-12);
for (String s : g2.gradientForVariable().keySet()) { for (String s : g2.gradientForVariable().keySet()) {
@ -330,7 +330,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
net.setLabel(0, labels); net.setLabel(0, labels);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score = net.score(); double score = net.getScore();
assertEquals(expScore, score, 0.1, msg); assertEquals(expScore, score, 0.1, msg);
} }

View File

@ -40,7 +40,7 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class BaseLayerTest extends BaseDL4JTest { public class BaseLayerConfigurationTest extends BaseDL4JTest {
protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2}); protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2});
protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2}); protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2});

View File

@ -56,10 +56,10 @@ public class CacheModeTest extends BaseDL4JTest {
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
net1.fit(in, labels); net1.fit(in, labels);
net2.fit(in, labels); net2.fit(in, labels);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
} }
private static NeuralNetConfiguration getConf(CacheMode cacheMode){ private static NeuralNetConfiguration getConf(CacheMode cacheMode){
@ -99,10 +99,10 @@ public class CacheModeTest extends BaseDL4JTest {
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
net1.fit(in, labels); net1.fit(in, labels);
net2.fit(in, labels); net2.fit(in, labels);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
} }
} }
@ -145,10 +145,10 @@ public class CacheModeTest extends BaseDL4JTest {
INDArray out2 = net2.outputSingle(in); INDArray out2 = net2.outputSingle(in);
assertEquals(out1, out2); assertEquals(out1, out2);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
net1.fit(new DataSet(in, labels)); net1.fit(new DataSet(in, labels));
net2.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels));
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
} }
private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){ private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){

View File

@ -121,7 +121,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
graph.setInput(0, input); graph.setInput(0, input);
graph.setLabel(0, labels); graph.setLabel(0, labels);
graph.computeGradientAndScore(); graph.computeGradientAndScore();
results[i] = graph.score(); results[i] = graph.getScore();
} }
assertNotEquals(results[0], results[1]); assertNotEquals(results[0], results[1]);
@ -137,7 +137,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
ComputationGraph net = getCNNMnistConfig(); ComputationGraph net = getCNNMnistConfig();
net.init(); net.init();
net.setListeners(new ScoreIterationListener(1)); net.addTrainingListeners(new ScoreIterationListener(1));
for (int i = 0; i < 50; i++) { for (int i = 0; i < 50; i++) {
net.fit(mnistTrain.next()); net.fit(mnistTrain.next());

View File

@ -265,7 +265,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init(); netSeparate.init();
assertEquals(netIntegrated.params(), netSeparate.params()); assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
netIntegrated.fit(next); netIntegrated.fit(next);
@ -273,7 +273,7 @@ public class DropoutLayerTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
netSeparate.fit(next); netSeparate.fit(next);
assertEquals(netIntegrated.params(), netSeparate.params()); assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams());
// check parameters // check parameters
assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));

View File

@ -80,7 +80,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
.setFeatureExtractor(1).build(); .setFeatureExtractor(1).build();
INDArray paramsLastTwoLayers = INDArray paramsLastTwoLayers =
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
MultiLayerNetwork notFrozen = new MultiLayerNetwork( MultiLayerNetwork notFrozen = new MultiLayerNetwork(
(NeuralNetConfiguration) overallConf.clone() (NeuralNetConfiguration) overallConf.clone()
.layer(0, new Builder().nIn(2).nOut(3).build()) .layer(0, new Builder().nIn(2).nOut(3).build())
@ -102,9 +102,9 @@ public class FrozenLayerTest extends BaseDL4JTest {
modelNow.fit(randomData); modelNow.fit(randomData);
} }
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
notFrozen.params()); notFrozen.getModelParams());
INDArray act = modelNow.params(); INDArray act = modelNow.getModelParams();
assertEquals(expected, act); assertEquals(expected, act);
} }
@ -136,7 +136,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson()); assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson());
//Check params //Check params
assertEquals(modelNow.params(), clonedModel.params()); assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());
MultiLayerNetwork notFrozen = new MultiLayerNetwork( MultiLayerNetwork notFrozen = new MultiLayerNetwork(
(NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build()) (NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build())
@ -145,7 +145,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
.activation(Activation.SOFTMAX).nIn(3).nOut(3) .activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build()) .build())
.build(), .build(),
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params())); Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()));
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
@ -155,10 +155,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(),
modelToFineTune.getLayer(1).params(), notFrozen.params()); modelToFineTune.getLayer(1).getParams(), notFrozen.getModelParams());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.getModelParams());
assertEquals(expectedParams, clonedModel.params()); assertEquals(expectedParams, clonedModel.getModelParams());
} }
@ -199,8 +199,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
.setOutputs("layer1").build()); .setOutputs("layer1").build());
notFrozen.init(); notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
modelToFineTune.getLayer("layer3").params())); modelToFineTune.getLayer("layer3").getParams()));
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
@ -209,8 +209,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
i++; i++;
} }
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params()); modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams()), modelNow.getModelParams());
} }
@Test @Test
@ -244,7 +244,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());
//Check params //Check params
assertEquals(modelNow.params(), clonedModel.params()); assertEquals(modelNow.getModelParams(), clonedModel.getModelParams());
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In") .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
@ -256,8 +256,8 @@ public class FrozenLayerTest extends BaseDL4JTest {
"layer0") "layer0")
.setOutputs("layer1").build()); .setOutputs("layer1").build());
notFrozen.init(); notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(),
modelToFineTune.getLayer("layer3").params())); modelToFineTune.getLayer("layer3").getParams()));
int i = 0; int i = 0;
@ -268,10 +268,10 @@ public class FrozenLayerTest extends BaseDL4JTest {
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params()); modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.getModelParams());
assertEquals(expectedParams, clonedModel.params()); assertEquals(expectedParams, clonedModel.getModelParams());
} }
@ -305,7 +305,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
String json = conf2.toJson(); String json = conf2.toJson();
@ -362,7 +362,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
ComputationGraph net2 = new ComputationGraph(conf2); ComputationGraph net2 = new ComputationGraph(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
String json = conf2.toJson(); String json = conf2.toJson();

View File

@ -75,7 +75,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
String json = conf2.toJson(); String json = conf2.toJson();
@ -130,7 +130,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
ComputationGraph net2 = new ComputationGraph(conf2); ComputationGraph net2 = new ComputationGraph(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
String json = conf2.toJson(); String json = conf2.toJson();
@ -170,19 +170,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(conf1); MultiLayerNetwork network = new MultiLayerNetwork(conf1);
network.init(); network.init();
INDArray unfrozenLayerParams = network.getLayer(0).params().dup(); INDArray unfrozenLayerParams = network.getLayer(0).getParams().dup();
INDArray frozenLayerParams1 = network.getLayer(1).params().dup(); INDArray frozenLayerParams1 = network.getLayer(1).getParams().dup();
INDArray frozenLayerParams2 = network.getLayer(2).params().dup(); INDArray frozenLayerParams2 = network.getLayer(2).getParams().dup();
INDArray frozenOutputLayerParams = network.getLayer(3).params().dup(); INDArray frozenOutputLayerParams = network.getLayer(3).getParams().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
network.fit(randomData); network.fit(randomData);
} }
assertNotEquals(unfrozenLayerParams, network.getLayer(0).params()); assertNotEquals(unfrozenLayerParams, network.getLayer(0).getParams());
assertEquals(frozenLayerParams1, network.getLayer(1).params()); assertEquals(frozenLayerParams1, network.getLayer(1).getParams());
assertEquals(frozenLayerParams2, network.getLayer(2).params()); assertEquals(frozenLayerParams2, network.getLayer(2).getParams());
assertEquals(frozenOutputLayerParams, network.getLayer(3).params()); assertEquals(frozenOutputLayerParams, network.getLayer(3).getParams());
} }
@ -228,19 +228,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
ComputationGraph computationGraph = new ComputationGraph(computationGraphConf); ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
computationGraph.init(); computationGraph.init();
INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).getParams().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
computationGraph.fit(randomData); computationGraph.fit(randomData);
} }
assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params()); assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).getParams());
} }
@ -275,17 +275,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.build(); .build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init(); frozenNetwork.init();
INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup(); INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).getParams().dup();
INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup(); INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).getParams().dup();
INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup(); INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).getParams().dup();
INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup(); INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).getParams().dup();
MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd); MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
sgdNetwork.init(); sgdNetwork.init();
INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup(); INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).getParams().dup();
INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup(); INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).getParams().dup();
INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup(); INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).getParams().dup();
INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup(); INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).getParams().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
frozenNetwork.fit(randomData); frozenNetwork.fit(randomData);
@ -294,10 +294,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
sgdNetwork.fit(randomData); sgdNetwork.fit(randomData);
} }
assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params()); assertEquals(frozenNetwork.getLayer(0).getParams(), sgdNetwork.getLayer(0).getParams());
assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params()); assertEquals(frozenNetwork.getLayer(1).getParams(), sgdNetwork.getLayer(1).getParams());
assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params()); assertEquals(frozenNetwork.getLayer(2).getParams(), sgdNetwork.getLayer(2).getParams());
assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params()); assertEquals(frozenNetwork.getLayer(3).getParams(), sgdNetwork.getLayer(3).getParams());
} }
@ -360,17 +360,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf); ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
frozenComputationGraph.init(); frozenComputationGraph.init();
INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).getParams().dup();
ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf); ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
sgdComputationGraph.init(); sgdComputationGraph.init();
INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup();
INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup();
INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup();
INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).getParams().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
frozenComputationGraph.fit(randomData); frozenComputationGraph.fit(randomData);
@ -379,10 +379,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
sgdComputationGraph.fit(randomData); sgdComputationGraph.fit(randomData);
} }
assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams());
assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).getParams(), sgdComputationGraph.getLayer(frozenBranchOutput).getParams());
} }
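Note: in the FrozenLayerWithBackprop hunks the per-layer accessor changes from params() to getParams(); each frozen layer's weights are dup()'d before training and compared afterwards. A compact sketch of that check, assuming the renamed getParams() accessor and a hypothetical two-layer network whose layer 1 is frozen:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import static org.junit.jupiter.api.Assertions.*;

    class FrozenLayerCheck {
        static void assertFrozenLayerUnchanged(MultiLayerNetwork net, DataSet data) {
            INDArray trainableBefore = net.getLayer(0).getParams().dup(); // detached snapshot
            INDArray frozenBefore    = net.getLayer(1).getParams().dup();
            for (int i = 0; i < 10; i++) {
                net.fit(data);
            }
            assertNotEquals(trainableBefore, net.getLayer(0).getParams()); // unfrozen layer learned
            assertEquals(frozenBefore, net.getLayer(1).getParams());       // frozen layer untouched
        }
    }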

View File

@ -68,9 +68,9 @@ public class OutputLayerTest extends BaseDL4JTest {
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf,
Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
params = l.params(); params = l.getModelParams();
l.setParamsTable(params); l.setParamsTable(params);
assertEquals(params, l.params()); assertEquals(params, l.getModelParams());
} }
@Test @Test
@ -217,8 +217,8 @@ public class OutputLayerTest extends BaseDL4JTest {
//However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping) //However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
//RnnOutputLayer has miniBatch examples //RnnOutputLayer has miniBatch examples
//Hence: expect difference in scores by factor of timeSeriesLength //Hence: expect difference in scores by factor of timeSeriesLength
double score = mln.score() * timeSeriesLength; double score = mln.getScore() * timeSeriesLength;
double scoreRNN = mlnRnn.score(); double scoreRNN = mlnRnn.getScore();
assertFalse(Double.isNaN(score)); assertFalse(Double.isNaN(score));
assertFalse(Double.isNaN(scoreRNN)); assertFalse(Double.isNaN(scoreRNN));
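Note on the factor of timeSeriesLength: both scores are per-"example" averages, but after reshaping the feed-forward OutputLayer treats every time step as its own example (miniBatch * timeSeriesLength rows) while RnnOutputLayer averages over miniBatch sequences, so the same total loss is divided by a denominator that is timeSeriesLength times larger. With identical parameters and data the two should agree once rescaled; a sketch of that check using the renamed getScore(), with a tolerance chosen here purely for illustration:

    double rescaled = mln.getScore() * timeSeriesLength;  // undo the extra 1/timeSeriesLength in the average
    double scoreRNN = mlnRnn.getScore();
    assertEquals(scoreRNN, rescaled, 1e-6 * Math.abs(scoreRNN)); // tolerance is an assumption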
@ -234,7 +234,7 @@ public class OutputLayerTest extends BaseDL4JTest {
RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer(); RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer();
//assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength}); //assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
//Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version. //Input may be set by BaseLayerConfiguration methods. Thus input may end up as reshaped 2d version instead of original 3d version.
//Not ideal, but everything else works. //Not ideal, but everything else works.
assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
@ -303,7 +303,7 @@ public class OutputLayerTest extends BaseDL4JTest {
MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2); MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
mln2.init(); mln2.init();
mln2.setParams(mln.params()); mln2.setParams(mln.getModelParams());
INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
@ -330,7 +330,7 @@ public class OutputLayerTest extends BaseDL4JTest {
mln2.computeGradientAndScore(); mln2.computeGradientAndScore();
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
assertEquals(mln.score(), mln2.score(), 1e-6); assertEquals(mln.getScore(), mln2.getScore(), 1e-6);
TestUtils.testModelSerialization(mln); TestUtils.testModelSerialization(mln);
} }
@ -386,7 +386,7 @@ public class OutputLayerTest extends BaseDL4JTest {
mln2.init(); mln2.init();
mln2.setParams(mln.params()); mln2.setParams(mln.getModelParams());
INDArray in = Nd4j.rand(3, 3, 5, 5); INDArray in = Nd4j.rand(3, 3, 5, 5);
@ -407,7 +407,7 @@ public class OutputLayerTest extends BaseDL4JTest {
mln.computeGradientAndScore(); mln.computeGradientAndScore();
mln2.computeGradientAndScore(); mln2.computeGradientAndScore();
assertEquals(mln.score(), mln2.score(), 1e-6); assertEquals(mln.getScore(), mln2.getScore(), 1e-6);
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
//Also check computeScoreForExamples //Also check computeScoreForExamples
@ -479,7 +479,7 @@ public class OutputLayerTest extends BaseDL4JTest {
graph2.init(); graph2.init();
graph2.setParams(graph.params()); graph2.setParams(graph.getModelParams());
INDArray in = Nd4j.rand(3, 3, 5, 5); INDArray in = Nd4j.rand(3, 3, 5, 5);
@ -500,7 +500,7 @@ public class OutputLayerTest extends BaseDL4JTest {
graph.computeGradientAndScore(); graph.computeGradientAndScore();
graph2.computeGradientAndScore(); graph2.computeGradientAndScore();
assertEquals(graph.score(), graph2.score(), 1e-6); assertEquals(graph.getScore(), graph2.getScore(), 1e-6);
assertEquals(graph.gradient().gradient(), graph2.gradient().gradient()); assertEquals(graph.gradient().gradient(), graph2.gradient().gradient());
//Also check computeScoreForExamples //Also check computeScoreForExamples

View File

@ -59,13 +59,13 @@ public class SeedTest extends BaseDL4JTest {
layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double score = layer.score(); double score = layer.getScore();
INDArray parameters = layer.params(); INDArray parameters = layer.getParams();
layer.setParams(parameters); layer.setParams(parameters);
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double score2 = layer.score(); double score2 = layer.getScore();
assertEquals(parameters, layer.params()); assertEquals(parameters, layer.getParams());
assertEquals(score, score2, 1e-4); assertEquals(score, score2, 1e-4);
} }
} }
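Note: the SeedTest hunk renames score() to getScore() and params() to getParams() on a single Layer; the test itself is a round-trip determinism check. A minimal sketch of the same pattern with the renamed accessors, assuming a hypothetical, already-fitted Layer instance:

    import org.deeplearning4j.nn.api.Layer;
    import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    class ScoreRoundTripCheck {
        static void assertScoreStable(Layer layer) {
            layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
            double before = layer.getScore();           // formerly layer.score()
            INDArray parameters = layer.getParams();    // formerly layer.params()
            layer.setParams(parameters);                // write the same values back
            layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
            assertEquals(before, layer.getScore(), 1e-4);
            assertEquals(parameters, layer.getParams());
        }
    }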

View File

@ -845,9 +845,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
public static void testHelper(TestCase tc) { public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params()); tc.net2.getModelParams().assign(tc.net1.getModelParams());
tc.net3.params().assign(tc.net1.params()); tc.net3.getModelParams().assign(tc.net1.getModelParams());
tc.net4.params().assign(tc.net1.params()); tc.net4.getModelParams().assign(tc.net1.getModelParams());
//Test forward pass: //Test forward pass:
INDArray inNCHW = tc.inNCHW; INDArray inNCHW = tc.inNCHW;
@ -909,9 +909,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
tc.net3.fit(inNHWC, tc.labelsNHWC); tc.net3.fit(inNHWC, tc.labelsNHWC);
tc.net4.fit(inNHWC, tc.labelsNHWC); tc.net4.fit(inNHWC, tc.labelsNHWC);
assertEquals(tc.net1.params(), tc.net2.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
assertEquals(tc.net1.params(), tc.net3.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
assertEquals(tc.net1.params(), tc.net4.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);
//Test serialization //Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);

View File

@ -30,7 +30,6 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
@ -38,7 +37,6 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitNormal;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -450,10 +448,10 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
MultiLayerNetwork net = getCNNMLNConfig(true, false); MultiLayerNetwork net = getCNNMLNConfig(true, false);
INDArray paramsOrig = net.params().dup(); INDArray paramsOrig = net.getModelParams().dup();
net.setParams(paramsOrig); net.setParams(paramsOrig);
INDArray params2 = net.params(); INDArray params2 = net.getModelParams();
assertEquals(paramsOrig, params2); assertEquals(paramsOrig, params2);
} }

View File

@ -154,7 +154,7 @@ public class TestCustomLayers extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net2.params(), net.params()); assertEquals(net2.getModelParams(), net.getModelParams());
INDArray testFeatures = Nd4j.rand(1, 10); INDArray testFeatures = Nd4j.rand(1, 10);
INDArray testLabels = Nd4j.zeros(1, 10); INDArray testLabels = Nd4j.zeros(1, 10);
@ -207,7 +207,7 @@ public class TestCustomLayers extends BaseDL4JTest {
ComputationGraph net2 = new ComputationGraph(conf2); ComputationGraph net2 = new ComputationGraph(conf2);
net2.init(); net2.init();
assertEquals(net2.params(), net.params()); assertEquals(net2.getModelParams(), net.getModelParams());
INDArray testFeatures = Nd4j.rand(1, 10); INDArray testFeatures = Nd4j.rand(1, 10);
INDArray testLabels = Nd4j.zeros(1, 10); INDArray testLabels = Nd4j.zeros(1, 10);

View File

@ -56,7 +56,7 @@ public class CustomLayer extends FeedForwardLayer {
boolean initializeParams, DataType networkDataType) { boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType); CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType);
ret.setListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);

View File

@ -54,7 +54,7 @@ public class CustomOutputLayer extends BaseOutputLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType); CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType);
ret.setListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
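Note: both custom-layer factories now register their listeners through addTrainingListeners(...) instead of setListeners(...); the rename suggests listeners are appended to whatever is already registered rather than replacing it. The same renamed call is used at network level throughout these tests; a one-line sketch, assuming the addTrainingListeners() name from this change set:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

    class ListenerWiring {
        static void attach(MultiLayerNetwork net) {
            net.addTrainingListeners(new ScoreIterationListener(100)); // log the score every 100 iterations
        }
    }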

View File

@ -72,7 +72,7 @@ public class DenseTest extends BaseDL4JTest {
DataSet test = iter.next(); DataSet test = iter.next();
assertEquals(model.params(), model2.params()); assertEquals(model.getModelParams(), model2.getModelParams());
Evaluation eval = new Evaluation(); Evaluation eval = new Evaluation();
INDArray output = model.output(test.getFeatures()); INDArray output = model.output(test.getFeatures());
@ -99,7 +99,7 @@ public class DenseTest extends BaseDL4JTest {
DataSet test = iter.next(); DataSet test = iter.next();
assertEquals(model.params(), model2.params()); assertEquals(model.getModelParams(), model2.getModelParams());
Evaluation eval = new Evaluation(); Evaluation eval = new Evaluation();
INDArray output = model.output(test.getFeatures()); INDArray output = model.output(test.getFeatures());

View File

@ -169,7 +169,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.init(); net.init();
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
int batchSize = 3; int batchSize = 3;
INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inEmbedding = Nd4j.create(batchSize, 1);
@ -216,7 +216,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.init(); net.init();
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
int batchSize = 3; int batchSize = 3;
INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inEmbedding = Nd4j.create(batchSize, 1);
@ -262,7 +262,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.init(); net.init();
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
int batchSize = 3; int batchSize = 3;
INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inEmbedding = Nd4j.create(batchSize, 1);
@ -287,7 +287,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.computeGradientAndScore(); net.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net2.score(), net.score(), 1e-6); assertEquals(net2.getScore(), net.getScore(), 1e-6);
Map<String, INDArray> gradient = net.gradient().gradientForVariable(); Map<String, INDArray> gradient = net.gradient().gradientForVariable();
Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@ -323,7 +323,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.init(); net.init();
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
int batchSize = 3; int batchSize = 3;
INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inEmbedding = Nd4j.create(batchSize, 1);
@ -349,7 +349,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net2.computeGradientAndScore(); net2.computeGradientAndScore();
// System.out.println(net.score() + "\t" + net2.score()); // System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-6); assertEquals(net2.getScore(), net.getScore(), 1e-6);
Map<String, INDArray> gradient = net.gradient().gradientForVariable(); Map<String, INDArray> gradient = net.gradient().gradientForVariable();
Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@ -395,7 +395,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.init(); net.init();
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength); INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength);
INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength);
@ -422,7 +422,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net2.computeGradientAndScore(); net2.computeGradientAndScore();
// System.out.println(net.score() + "\t" + net2.score()); // System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-5); assertEquals(net2.getScore(), net.getScore(), 1e-5);
Map<String, INDArray> gradient = net.gradient().gradientForVariable(); Map<String, INDArray> gradient = net.gradient().gradientForVariable();
Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradient2 = net2.gradient().gradientForVariable();
@ -484,7 +484,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
@ -523,7 +523,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net2.computeGradientAndScore(); net2.computeGradientAndScore();
// System.out.println(net.score() + "\t" + net2.score()); // System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-5); assertEquals(net2.getScore(), net.getScore(), 1e-5);
Map<String, INDArray> gradients = net.gradient().gradientForVariable(); Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
@ -640,7 +640,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength}); INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength); INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
@ -678,7 +678,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.computeGradientAndScore(); net.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net2.score(), net.score(), 1e-5); assertEquals(net2.getScore(), net.getScore(), 1e-5);
Map<String, INDArray> gradients = net.gradient().gradientForVariable(); Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
@ -777,9 +777,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
net3.init(); net3.init();
INDArray p1 = net.params(); INDArray p1 = net.getModelParams();
INDArray p2 = net2.params(); INDArray p2 = net2.getModelParams();
INDArray p3 = net3.params(); INDArray p3 = net3.getModelParams();
boolean eq = p1.equalsWithEps(p2, 1e-4); boolean eq = p1.equalsWithEps(p2, 1e-4);
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi; String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
assertTrue(eq, str + " p1/p2 params not equal"); assertTrue(eq, str + " p1/p2 params not equal");
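Note: a recurring pattern in the embedding tests (and in most of the hunks below) is building two differently configured but equivalent networks, copying parameters from one to the other, then comparing scores and gradients. A compact sketch of that pattern with the renamed accessors (getModelParams(), getScore()), assuming two hypothetical already-initialised networks with identical parameter layout:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    class EquivalenceCheck {
        static void assertSameScoreAndGradient(MultiLayerNetwork net1, MultiLayerNetwork net2,
                                               INDArray features, INDArray labels) {
            net2.setParams(net1.getModelParams().dup()); // copy a detached snapshot of net1's parameters
            net1.setInput(features);  net1.setLabels(labels);
            net2.setInput(features);  net2.setLabels(labels);
            net1.computeGradientAndScore();
            net2.computeGradientAndScore();
            assertEquals(net1.getScore(), net2.getScore(), 1e-6);
            assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
        }
    }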

View File

@ -514,7 +514,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
net.setListeners(new ScoreIterationListener(100)); net.addTrainingListeners(new ScoreIterationListener(100));
int nEpochs = 1000; int nEpochs = 1000;
DataSet ds = iter.next(); DataSet ds = iter.next();

View File

@ -79,13 +79,13 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
if (doLearningFirst) { if (doLearningFirst) {
//Run a number of iterations of learning //Run a number of iterations of learning
network.setInput(arr); network.setInput(arr);
network.setListeners(new ScoreIterationListener(1)); network.addTrainingListeners(new ScoreIterationListener(1));
network.computeGradientAndScore(); network.computeGradientAndScore();
double scoreBefore = network.score(); double scoreBefore = network.getScore();
for (int j = 0; j < 10; j++) for (int j = 0; j < 10; j++)
network.fit(ds); network.fit(ds);
network.computeGradientAndScore(); network.computeGradientAndScore();
double scoreAfter = network.score(); double scoreAfter = network.getScore();
//Can't test in 'characteristic mode of operation' if not learning //Can't test in 'characteristic mode of operation' if not learning
String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn=" String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn="
+ "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid" + "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid"
@ -147,7 +147,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile); MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile);
assertEquals(network.params(),multiLayerNetwork.params()); assertEquals(network.getModelParams(),multiLayerNetwork.getModelParams());
assertEquals(network.numParams(),multiLayerNetwork.numParams()); assertEquals(network.numParams(),multiLayerNetwork.numParams());
} }
@ -187,7 +187,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
.build(); .build();
MultiLayerNetwork network = new MultiLayerNetwork(configuration); MultiLayerNetwork network = new MultiLayerNetwork(configuration);
network.init(); network.init();
network.setListeners(new ScoreIterationListener(1)); network.addTrainingListeners(new ScoreIterationListener(1));
return network; return network;
} }

View File

@ -124,7 +124,7 @@ public class BidirectionalTest extends BaseDL4JTest {
assertEquals(n1, n2); assertEquals(n1, n2);
} }
net2.setParams(net1.params()); //Assuming exact same layout here... net2.setParams(net1.getModelParams()); //Assuming exact same layout here...
INDArray in; INDArray in;
if (rnnDataFormat == NCW){ if (rnnDataFormat == NCW){
@ -154,7 +154,7 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.computeGradientAndScore(); net2.computeGradientAndScore();
//Ensure scores are equal: //Ensure scores are equal:
assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.getScore(), net2.getScore(), 1e-6);
//Ensure gradients are equal: //Ensure gradients are equal:
Gradient g1 = net1.gradient(); Gradient g1 = net1.gradient();
@ -174,8 +174,8 @@ public class BidirectionalTest extends BaseDL4JTest {
net1.fit(in, labels); net1.fit(in, labels);
net2.fit(in, labels); net2.fit(in, labels);
INDArray p1 = net1.params(); INDArray p1 = net1.getModelParams();
INDArray p2 = net2.params(); INDArray p2 = net2.getModelParams();
assertEquals(p1, p2); assertEquals(p1, p2);
} }
} }
@ -232,7 +232,7 @@ public class BidirectionalTest extends BaseDL4JTest {
assertEquals(n1, n2); assertEquals(n1, n2);
} }
net2.setParams(net1.params()); //Assuming exact same layout here... net2.setParams(net1.getModelParams()); //Assuming exact same layout here...
INDArray in = Nd4j.rand(3, 10, 5); INDArray in = Nd4j.rand(3, 10, 5);
@ -253,7 +253,7 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.computeGradientAndScore(); net2.computeGradientAndScore();
//Ensure scores are equal: //Ensure scores are equal:
assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.getScore(), net2.getScore(), 1e-6);
//Ensure gradients are equal: //Ensure gradients are equal:
Gradient g1 = net1.gradient(); Gradient g1 = net1.gradient();
@ -273,8 +273,8 @@ public class BidirectionalTest extends BaseDL4JTest {
net1.fit(new DataSet(in, labels)); net1.fit(new DataSet(in, labels));
net2.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels));
INDArray p1 = net1.params(); INDArray p1 = net1.getModelParams();
INDArray p2 = net2.params(); INDArray p2 = net2.getModelParams();
assertEquals(p1, p2); assertEquals(p1, p2);
} }
} }
@ -340,7 +340,7 @@ public class BidirectionalTest extends BaseDL4JTest {
net1.computeGradientAndScore(); net1.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.getScore(), net2.getScore(), 1e-6);
assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
} }
} }
@ -403,7 +403,7 @@ public class BidirectionalTest extends BaseDL4JTest {
net1.computeGradientAndScore(); net1.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.getScore(), net2.getScore(), 1e-6);
assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
} }
} }

View File

@ -277,7 +277,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
params = bidirectionalLSTM.params(); params = bidirectionalLSTM.getModelParams();
bidirectionalLSTM.setParamsTable(params); bidirectionalLSTM.setParamsTable(params);

View File

@ -285,9 +285,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {
public static void testHelper(TestCase tc) { public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params()); tc.net2.getModelParams().assign(tc.net1.getModelParams());
tc.net3.params().assign(tc.net1.params()); tc.net3.getModelParams().assign(tc.net1.getModelParams());
tc.net4.params().assign(tc.net1.params()); tc.net4.getModelParams().assign(tc.net1.getModelParams());
INDArray inNCW = tc.inNCW; INDArray inNCW = tc.inNCW;
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup(); INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
@ -352,9 +352,9 @@ public class RnnDataFormatTests extends BaseDL4JTest {
tc.net3.fit(inNWC, tc.labelsNWC); tc.net3.fit(inNWC, tc.labelsNWC);
tc.net4.fit(inNWC, tc.labelsNWC); tc.net4.fit(inNWC, tc.labelsNWC);
assertEquals(tc.net1.params(), tc.net2.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg);
assertEquals(tc.net1.params(), tc.net3.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg);
assertEquals(tc.net1.params(), tc.net4.params(), tc.msg); assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg);
//Test serialization //Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);

View File

@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -173,8 +172,8 @@ public class TestRnnLayers extends BaseDL4JTest {
MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2); MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2);
netD2.init(); netD2.init();
assertEquals(net.params(), netD.params(), s); assertEquals(net.getModelParams(), netD.getModelParams(), s);
assertEquals(net.params(), netD2.params(), s); assertEquals(net.getModelParams(), netD2.getModelParams(), s);
INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10); INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10);
@ -193,7 +192,7 @@ public class TestRnnLayers extends BaseDL4JTest {
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
net.fit(f.dup(), l); net.fit(f.dup(), l);
netD.fit(f.dup(), l); netD.fit(f.dup(), l);
assertNotEquals(net.params(), netD.params(), s); assertNotEquals(net.getModelParams(), netD.getModelParams(), s);
netD2.fit(f.dup(), l); netD2.fit(f.dup(), l);
netD2.fit(f.dup(), l); netD2.fit(f.dup(), l);

View File

@ -115,7 +115,7 @@ public class TestTimeDistributed extends BaseDL4JTest {
net1.fit(ds); net1.fit(ds);
net2.fit(ds); net2.fit(ds);
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2); MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
out2 = net2.output(in); out2 = net2.output(in);

View File

@ -124,10 +124,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
net.params().assign(net2.params()); net.getModelParams().assign(net2.getModelParams());
//Check params: //Check params:
assertEquals(net2.params(), net.params()); assertEquals(net2.getModelParams(), net.getModelParams());
Map<String, INDArray> params1 = net.getParamTable(); Map<String, INDArray> params1 = net.getParamTable();
Map<String, INDArray> params2 = net2.getParamTable(); Map<String, INDArray> params2 = net2.getParamTable();
assertEquals(params2, params1); assertEquals(params2, params1);
@ -209,10 +209,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net2.params(), net.params()); assertEquals(net2.getModelParams(), net.getModelParams());
//Check params: //Check params:
assertEquals(net2.params(), net.params()); assertEquals(net2.getModelParams(), net.getModelParams());
Map<String, INDArray> params1 = net.getParamTable(); Map<String, INDArray> params1 = net.getParamTable();
Map<String, INDArray> params2 = net2.getParamTable(); Map<String, INDArray> params2 = net2.getParamTable();
assertEquals(params2, params1); assertEquals(params2, params1);
@ -287,10 +287,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2); MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
netStandard.init(); netStandard.init();
netSD.params().assign(netStandard.params()); netSD.getModelParams().assign(netStandard.getModelParams());
//Check params: //Check params:
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.getModelParams(), netSD.getModelParams());
assertEquals(netStandard.getParamTable(), netSD.getParamTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
INDArray in = Nd4j.rand(minibatch, nIn); INDArray in = Nd4j.rand(minibatch, nIn);
@ -379,10 +379,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2); MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
netStandard.init(); netStandard.init();
netSD.params().assign(netStandard.params()); netSD.getModelParams().assign(netStandard.getModelParams());
//Check params: //Check params:
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.getModelParams(), netSD.getModelParams());
assertEquals(netStandard.getParamTable(), netSD.getParamTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -398,7 +398,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
netStandard.fit(ds); netStandard.fit(ds);
String s = String.valueOf(i); String s = String.valueOf(i);
assertEquals( netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s); assertEquals( netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s);
assertEquals( netStandard.params(), netSD.params(), s); assertEquals( netStandard.getModelParams(), netSD.getModelParams(), s);
assertEquals( netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s); assertEquals( netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s);
} }

View File

@ -100,10 +100,10 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
ComputationGraph netStandard = new ComputationGraph(conf2); ComputationGraph netStandard = new ComputationGraph(conf2);
netStandard.init(); netStandard.init();
netSD.params().assign(netStandard.params()); netSD.getModelParams().assign(netStandard.getModelParams());
//Check params: //Check params:
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.getModelParams(), netSD.getModelParams());
assertEquals(netStandard.getParamTable(), netSD.getParamTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
INDArray in = Nd4j.rand(minibatch, nIn); INDArray in = Nd4j.rand(minibatch, nIn);
@ -160,7 +160,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
netStandard.fit(ds); netStandard.fit(ds);
assertEquals(netStandard.getParamTable(), netSD.getParamTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.getModelParams(), netSD.getModelParams());
assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients()); assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients());
} }

View File

@ -98,7 +98,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
ComputationGraph std = new ComputationGraph(confStd); ComputationGraph std = new ComputationGraph(confStd);
std.init(); std.init();
lambda.setParams(std.params()); lambda.setParams(std.getModelParams());
INDArray in = Nd4j.rand(3, 5); INDArray in = Nd4j.rand(3, 5);
INDArray labels = TestUtils.randomOneHot(3, 5); INDArray labels = TestUtils.randomOneHot(3, 5);
@ -119,7 +119,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
std.fit(ds); std.fit(ds);
String s = String.valueOf(i); String s = String.valueOf(i);
assertEquals(std.params(), lambda.params(), s); assertEquals(std.getModelParams(), lambda.getModelParams(), s);
assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
} }
@ -182,7 +182,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
ComputationGraph std = new ComputationGraph(confStd); ComputationGraph std = new ComputationGraph(confStd);
std.init(); std.init();
lambda.setParams(std.params()); lambda.setParams(std.getModelParams());
INDArray in1 = Nd4j.rand(3, 5); INDArray in1 = Nd4j.rand(3, 5);
INDArray in2 = Nd4j.rand(3, 5); INDArray in2 = Nd4j.rand(3, 5);
@ -204,7 +204,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {
std.fit(mds); std.fit(mds);
String s = String.valueOf(i); String s = String.valueOf(i);
assertEquals(std.params(), lambda.params(), s); assertEquals(std.getModelParams(), lambda.getModelParams(), s);
assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s);
} }

View File

@ -85,7 +85,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
netSD.fit(ds); netSD.fit(ds);
netStd.fit(ds); netStd.fit(ds);
assertEquals(netStd.params(), netSD.params()); assertEquals(netStd.getModelParams(), netSD.getModelParams());
assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients()); assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients());
} }
@ -131,7 +131,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
MultiLayerNetwork netStd = new MultiLayerNetwork(confStd); MultiLayerNetwork netStd = new MultiLayerNetwork(confStd);
netStd.init(); netStd.init();
netSD.params().assign(netStd.params()); netSD.getModelParams().assign(netStd.getModelParams());
assertEquals(netStd.getParamTable(), netSD.getParamTable()); assertEquals(netStd.getParamTable(), netSD.getParamTable());
@ -165,7 +165,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
netSD.fit(ds); netSD.fit(ds);
netStd.fit(ds); netStd.fit(ds);
String s = String.valueOf(i); String s = String.valueOf(i);
assertEquals( netStd.params(), netSD.params(), s); assertEquals( netStd.getModelParams(), netSD.getModelParams(), s);
assertEquals( netStd.getFlattenedGradients(), netSD.getFlattenedGradients(),s ); assertEquals( netStd.getFlattenedGradients(), netSD.getFlattenedGradients(),s );
} }

View File

@ -77,7 +77,7 @@ public class TestVAE extends BaseDL4JTest {
net.init(); net.init();
System.out.println("Exp num params: " + expNumParams); System.out.println("Exp num params: " + expNumParams);
assertEquals(expNumParams, net.getLayer(0).params().length()); assertEquals(expNumParams, net.getLayer(0).getParams().length());
Map<String, INDArray> paramTable = net.getLayer(0).getParamTable(); Map<String, INDArray> paramTable = net.getLayer(0).getParamTable();
int count = 0; int count = 0;
for (INDArray arr : paramTable.values()) { for (INDArray arr : paramTable.values()) {

View File

@ -79,7 +79,7 @@ public class CloseNetworkTests extends BaseDL4JTest {
net.close(); net.close();
assertTrue(net.params().wasClosed()); assertTrue(net.getModelParams().wasClosed());
if(train) { if(train) {
assertTrue(net.getGradientsViewArray().wasClosed()); assertTrue(net.getGradientsViewArray().wasClosed());
Updater u = net.getUpdater(false); Updater u = net.getUpdater(false);
@ -127,7 +127,7 @@ public class CloseNetworkTests extends BaseDL4JTest {
net.close(); net.close();
assertTrue(net.params().wasClosed()); assertTrue(net.getModelParams().wasClosed());
if(train) { if(train) {
assertTrue(net.getGradientsViewArray().wasClosed()); assertTrue(net.getGradientsViewArray().wasClosed());
Updater u = net.getUpdater(false); Updater u = net.getUpdater(false);
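Note: the CloseNetworkTests hunks move the closed-buffer check onto getModelParams(); after net.close() the off-heap buffer backing the flattened parameter view is released and wasClosed() reports true (the gradients view is checked the same way when the network was trained). A two-line sketch, assuming a hypothetical initialised network:

    net.close();                                  // releases parameter (and, after training, gradient/updater) buffers
    assertTrue(net.getModelParams().wasClosed()); // the flattened parameter view must not be used after this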

View File

@ -57,7 +57,7 @@ public class LargeNetTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray params = net.params(); INDArray params = net.getModelParams();
long paramsLength = params.length(); long paramsLength = params.length();
long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
assertEquals(expParamsLength, paramsLength); assertEquals(expParamsLength, paramsLength);
@ -91,7 +91,7 @@ public class LargeNetTest extends BaseDL4JTest {
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.init(); net.init();
INDArray params = net.params(); INDArray params = net.getModelParams();
long paramsLength = params.length(); long paramsLength = params.length();
long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10;
assertEquals(expParamsLength, paramsLength); assertEquals(expParamsLength, paramsLength);

View File

@ -76,7 +76,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.init(); net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount()); conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0); assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0);
net.setLearningRate(0, 0.5); //Set LR for layer 0 to 0.5 net.setLearningRate(0, 0.5); //Set LR for layer 0 to 0.5
@ -96,7 +96,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.fit(in, l); net2.fit(in, l);
} }
assertEquals(net.params(), net2.params()); assertEquals(net.getModelParams(), net2.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
INDArray in1 = Nd4j.rand(10, 10); INDArray in1 = Nd4j.rand(10, 10);
@ -110,7 +110,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.setLabels(l1); net2.setLabels(l1);
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net.score(), net2.score(), 1e-8); assertEquals(net.getScore(), net2.getScore(), 1e-8);
//Now: Set *all* LRs to say 0.3... //Now: Set *all* LRs to say 0.3...
@ -126,7 +126,7 @@ public class TestLrChanges extends BaseDL4JTest {
net3.init(); net3.init();
net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf3.setIterationCount(conf.getIterationCount()); conf3.setIterationCount(conf.getIterationCount());
net3.setParams(net.params().dup()); net3.setParams(net.getModelParams().dup());
net.setLearningRate(0.3); net.setLearningRate(0.3);
@ -139,7 +139,7 @@ public class TestLrChanges extends BaseDL4JTest {
net3.fit(in, l); net3.fit(in, l);
} }
assertEquals(net.params(), net3.params()); assertEquals(net.getModelParams(), net3.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
} }
@ -206,7 +206,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.init(); net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount()); conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5 net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5
@ -224,7 +224,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.fit(in, l); net2.fit(in, l);
} }
assertEquals(net.params(), net2.params()); assertEquals(net.getModelParams(), net2.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
} }
@ -270,7 +270,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.init(); net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount()); conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0); assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0);
net.setLearningRate("0", 0.5); //Set LR for layer 0 to 0.5 net.setLearningRate("0", 0.5); //Set LR for layer 0 to 0.5
@ -290,7 +290,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.fit(new DataSet(in, l)); net2.fit(new DataSet(in, l));
} }
assertEquals(net.params(), net2.params()); assertEquals(net.getModelParams(), net2.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
INDArray in1 = Nd4j.rand(10, 10); INDArray in1 = Nd4j.rand(10, 10);
@ -304,7 +304,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.setLabels(l1); net2.setLabels(l1);
net2.computeGradientAndScore(); net2.computeGradientAndScore();
assertEquals(net.score(), net2.score(), 1e-8); assertEquals(net.getScore(), net2.getScore(), 1e-8);
//Now: Set *all* LRs to say 0.3... //Now: Set *all* LRs to say 0.3...
@ -320,7 +320,7 @@ public class TestLrChanges extends BaseDL4JTest {
net3.init(); net3.init();
net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf3.setIterationCount(conf.getIterationCount()); conf3.setIterationCount(conf.getIterationCount());
net3.setParams(net.params().dup()); net3.setParams(net.getModelParams().dup());
net.setLearningRate(0.3); net.setLearningRate(0.3);
@ -333,7 +333,7 @@ public class TestLrChanges extends BaseDL4JTest {
net3.fit(new DataSet(in, l)); net3.fit(new DataSet(in, l));
} }
assertEquals(net.params(), net3.params()); assertEquals(net.getModelParams(), net3.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray());
} }
@ -375,7 +375,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.init(); net2.init();
net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray());
conf2.setIterationCount(conf.getIterationCount()); conf2.setIterationCount(conf.getIterationCount());
net2.setParams(net.params().dup()); net2.setParams(net.getModelParams().dup());
net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5 net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5
@ -393,7 +393,7 @@ public class TestLrChanges extends BaseDL4JTest {
net2.fit(new DataSet(in, l)); net2.fit(new DataSet(in, l));
} }
assertEquals(net.params(), net2.params()); assertEquals(net.getModelParams(), net2.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray());
} }
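
The pattern repeated throughout this file (duplicate the flattened parameters, copy the updater state, then compare scores) reduces to the sketch below; only getModelParams(), getScore() and the updater accessors come from the diff, the helper names are hypothetical.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class NetworkSyncSketch {
    // Sketch: bring 'copy' into the same state as 'source' before comparing training behaviour.
    static void syncFrom(MultiLayerNetwork source, MultiLayerNetwork copy) {
        copy.setParams(source.getModelParams().dup());     // was source.params().dup()
        if (source.getUpdater().getStateViewArray() != null) {
            copy.getUpdater().getStateViewArray().assign(source.getUpdater().getStateViewArray());
        }
    }

    // Sketch: compare the renamed score accessor on two networks (was a.score(), b.score()).
    static boolean scoresMatch(MultiLayerNetwork a, MultiLayerNetwork b, double eps) {
        return Math.abs(a.getScore() - b.getScore()) < eps;
    }
}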

View File

@ -77,14 +77,14 @@ public class TestNetConversion extends BaseDL4JTest {
n.computeGradientAndScore(); n.computeGradientAndScore();
cg.computeGradientAndScore(); cg.computeGradientAndScore();
assertEquals(n.score(), cg.score(), 1e-6); assertEquals(n.getScore(), cg.getScore(), 1e-6);
assertEquals(n.gradient().gradient(), cg.gradient().gradient()); assertEquals(n.gradient().gradient(), cg.gradient().gradient());
n.fit(in, labels); n.fit(in, labels);
cg.fit(new INDArray[]{in}, new INDArray[]{labels}); cg.fit(new INDArray[]{in}, new INDArray[]{labels});
assertEquals(n.params(), cg.params()); assertEquals(n.getModelParams(), cg.getModelParams());
} }
} }

View File

@ -476,7 +476,7 @@ public class WorkspaceTests extends BaseDL4JTest {
final ComputationGraph computationGraph = new ComputationGraph(config); final ComputationGraph computationGraph = new ComputationGraph(config);
computationGraph.init(); computationGraph.init();
computationGraph.setListeners(new ScoreIterationListener(3)); computationGraph.addTrainingListeners(new ScoreIterationListener(3));
WSTestDataSetIterator iterator = new WSTestDataSetIterator(); WSTestDataSetIterator iterator = new WSTestDataSetIterator();
computationGraph.fit(iterator); computationGraph.fit(iterator);

View File

@ -54,7 +54,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
public void testMLPTrivial() { public void testMLPTrivial() {
//Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1.
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID)); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID));
network.setListeners(new ScoreIterationListener(1)); network.addTrainingListeners(new ScoreIterationListener(1));
network.init(); network.init();
DataSetIterator iter = new IrisDataSetIterator(1, 10); DataSetIterator iter = new IrisDataSetIterator(1, 10);
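
Both listener hunks above replace setListeners with addTrainingListeners; judging by the other call sites in this commit, repeated calls accumulate listeners rather than overwrite them. A hedged sketch (listener frequencies chosen arbitrarily):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

class ListenerSketch {
    // Sketch: attach training listeners to an existing network via the renamed method.
    static void attachListeners(MultiLayerNetwork net) {
        net.addTrainingListeners(new ScoreIterationListener(1));       // was net.setListeners(...)
        net.addTrainingListeners(new PerformanceListener(10, true));   // additional listeners accumulate
    }
}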

View File

@ -64,7 +64,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.AutoEncoder; import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
@ -184,13 +184,13 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork network3 = new MultiLayerNetwork(conf); MultiLayerNetwork network3 = new MultiLayerNetwork(conf);
network3.init(); network3.init();
INDArray params = network3.params(); INDArray params = network3.getModelParams();
INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup();
INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup(); INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup();
network3.setParameters(params); network3.setParameters(params);
assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY)); assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY));
assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY)); assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY));
INDArray params4 = network3.params(); INDArray params4 = network3.getModelParams();
assertEquals(params, params4); assertEquals(params, params4);
} }
@ -211,7 +211,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init(); network.init();
network.setListeners(new ScoreIterationListener(1)); network.addTrainingListeners(new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -242,7 +242,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init(); network.init();
network.setListeners(new ScoreIterationListener(1)); network.addTrainingListeners(new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -330,7 +330,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork model = new MultiLayerNetwork(conf); MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init(); model.init();
model.addListeners(new ScoreIterationListener(listenerFreq)); model.addTrainingListeners(new ScoreIterationListener(listenerFreq));
log.info("Train model...."); log.info("Train model....");
int cnt = 0; int cnt = 0;
@ -503,7 +503,7 @@ public class MultiLayerTest extends BaseDL4JTest {
assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName()); assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName());
assertEquals(layerNameList, net.getLayerNames()); assertEquals(layerNameList, net.getLayerNames());
BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).getLayerConfiguration(); BaseLayerConfiguration b = (BaseLayerConfiguration) net.getLayer(layerNameList.get(2)).getLayerConfiguration();
assertEquals("softmax", b.getActivationFn().toString()); assertEquals("softmax", b.getActivationFn().toString());
} }
@ -535,7 +535,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg); MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg);
netNoReg.init(); netNoReg.init();
netNoReg.setParameters(net.params().dup()); netNoReg.setParameters(net.getModelParams().dup());
//Score single example, and compare to scoreExamples: //Score single example, and compare to scoreExamples:
INDArray input = Nd4j.rand(3, nIn); INDArray input = Nd4j.rand(3, nIn);
@ -703,7 +703,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
net.fit(iter.next()); net.fit(iter.next());
// TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars // TODO validate actual layer gradientView - issue getting var out of BaseLayerConfiguration w/o adding MLN getter that gets confused with local gradient vars
Gradient actualGradient = net.gradient; Gradient actualGradient = net.gradient;
assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W"));
@ -716,13 +716,13 @@ public class MultiLayerTest extends BaseDL4JTest {
net.setParam("0_b", Nd4j.ones(1, 5)); net.setParam("0_b", Nd4j.ones(1, 5));
net.setParam("1_W", Nd4j.ones(5, 3)); net.setParam("1_W", Nd4j.ones(5, 3));
net.setParam("1_b", Nd4j.ones(1, 3)); net.setParam("1_b", Nd4j.ones(1, 3));
INDArray actualParams = net.params(); INDArray actualParams = net.getModelParams();
// Confirm params // Confirm params
assertEquals(expectedGradient.gradient(), actualParams); assertEquals(expectedGradient.gradient(), actualParams);
net.update(expectedGradient); net.update(expectedGradient);
actualParams = net.params(); actualParams = net.getModelParams();
assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); assertEquals(Nd4j.ones(1, 43).addi(1), actualParams);
} }
@ -762,7 +762,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork aePre = getAeModel(true, nIn, nOut); MultiLayerNetwork aePre = getAeModel(true, nIn, nOut);
int actualNP = (int) aePre.numParams(); int actualNP = (int) aePre.numParams();
assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
INDArray params = aePre.params(); INDArray params = aePre.getModelParams();
assertEquals(params.length(), actualNP); // check num params assertEquals(params.length(), actualNP); // check num params
Map<String, INDArray> paramTable = aePre.getParamTable(); Map<String, INDArray> paramTable = aePre.getParamTable();
assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer
@ -774,7 +774,7 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut); MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut);
actualNP = (int) aeNoPre.numParams(); actualNP = (int) aeNoPre.numParams();
assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP);
params = aeNoPre.params(); params = aeNoPre.getModelParams();
assertEquals(params.length(), actualNP); assertEquals(params.length(), actualNP);
paramTable = aePre.getParamTable(); paramTable = aePre.getParamTable();
assertTrue(paramTable.containsKey("0_vb")); assertTrue(paramTable.containsKey("0_vb"));
@ -865,14 +865,14 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
BaseLayer bl0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); BaseLayerConfiguration bl0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6); assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6);
assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6); assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6);
INDArray features = Nd4j.rand(10, 10); INDArray features = Nd4j.rand(10, 10);
INDArray labels = Nd4j.rand(10, 10); INDArray labels = Nd4j.rand(10, 10);
net2.setParams(net1.params().dup()); net2.setParams(net1.getModelParams().dup());
net1.setInput(features); net1.setInput(features);
net1.setLabels(labels); net1.setLabels(labels);
@ -888,15 +888,15 @@ public class MultiLayerTest extends BaseDL4JTest {
r = net2.calcRegularizationScore(true); r = net2.calcRegularizationScore(true);
assertEquals(0.0, r, 0.0); assertEquals(0.0, r, 0.0);
double s1 = net1.score(); double s1 = net1.getScore();
double s2 = net2.score(); double s2 = net2.getScore();
assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
net1.fit(features, labels); net1.fit(features, labels);
} }
net2.setParams(net1.params().dup()); net2.setParams(net1.getModelParams().dup());
net1.computeGradientAndScore(); net1.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
@ -906,8 +906,8 @@ public class MultiLayerTest extends BaseDL4JTest {
r = net2.calcRegularizationScore(true); r = net2.calcRegularizationScore(true);
assertTrue(r > 0.0); assertTrue(r > 0.0);
s1 = net1.score(); s1 = net1.getScore();
s2 = net2.score(); s2 = net2.getScore();
assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2 assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2
@ -1022,11 +1022,11 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertNotEquals(net1.params(), net2.params()); assertNotEquals(net1.getModelParams(), net2.getModelParams());
assertNotEquals(net1.getParamTable(), net2.getParamTable()); assertNotEquals(net1.getParamTable(), net2.getParamTable());
net1.setParamTable(net2.getParamTable()); net1.setParamTable(net2.getParamTable());
assertEquals(net1.params(), net2.params()); assertEquals(net1.getModelParams(), net2.getModelParams());
assertEquals(net1.getParamTable(), net2.getParamTable()); assertEquals(net1.getParamTable(), net2.getParamTable());
} }
@ -1412,7 +1412,7 @@ public class MultiLayerTest extends BaseDL4JTest {
exp.add(MultiLayerNetwork.class); exp.add(MultiLayerNetwork.class);
CheckModelsListener listener = new CheckModelsListener(); CheckModelsListener listener = new CheckModelsListener();
net.setListeners(listener); net.addTrainingListeners(listener);
INDArray f = Nd4j.create(1, 10); INDArray f = Nd4j.create(1, 10);
INDArray l = Nd4j.create(1, 10); INDArray l = Nd4j.create(1, 10);
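
Several hunks in this file rename the configuration base class from BaseLayer to BaseLayerConfiguration; a sketch of the updated cast when inspecting a layer's settings (the layer index is illustrative):

import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class LayerConfigSketch {
    // Sketch: read configuration details through the renamed base class.
    static String activationOf(MultiLayerNetwork net, int layerIndex) {
        BaseLayerConfiguration cfg =
                (BaseLayerConfiguration) net.getLayer(layerIndex).getLayerConfiguration();
        return cfg.getActivationFn().toString();   // e.g. "softmax" for a softmax output layer
    }
}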

View File

@ -753,9 +753,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput); DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput);
INDArray initialParams = mln.params().dup(); INDArray initialParams = mln.getModelParams().dup();
mln.fit(ds); mln.fit(ds);
INDArray afterParams = mln.params(); INDArray afterParams = mln.getModelParams();
assertNotEquals(initialParams, afterParams); assertNotEquals(initialParams, afterParams);
} }
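
The check above is the usual fit-updates-parameters idiom, snapshotting the flattened view before training; a minimal sketch (the DataSet is assumed to match the network's input and label shapes):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

class FitCheckSketch {
    // Sketch: verify that a single fit call actually changes the model parameters.
    static boolean fitUpdatesParams(MultiLayerNetwork net, DataSet ds) {
        INDArray before = net.getModelParams().dup();   // dup() detaches the snapshot from the live view
        net.fit(ds);
        return !before.equals(net.getModelParams());
    }
}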

View File

@ -172,7 +172,7 @@ public class TestMasking extends BaseDL4JTest {
net.setLabels(labels); net.setLabels(labels);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score1 = net.score(); double score1 = net.getScore();
INDArray grad1 = net.gradient().gradient(); INDArray grad1 = net.gradient().gradient();
//Now: change the label values for the masked steps. The //Now: change the label values for the masked steps. The
@ -187,7 +187,7 @@ public class TestMasking extends BaseDL4JTest {
assertNotEquals(labels, newLabels); assertNotEquals(labels, newLabels);
double score2 = net.score(); double score2 = net.getScore();
INDArray grad2 = net.gradient().gradient(); INDArray grad2 = net.gradient().gradient();
assertEquals(score1, score2, 1e-6); assertEquals(score1, score2, 1e-6);
@ -214,7 +214,7 @@ public class TestMasking extends BaseDL4JTest {
graph.setLabels(labels); graph.setLabels(labels);
graph.computeGradientAndScore(); graph.computeGradientAndScore();
double gScore1 = graph.score(); double gScore1 = graph.getScore();
INDArray gGrad1 = graph.gradient().gradient(); INDArray gGrad1 = graph.gradient().gradient();
graph.setLayerMaskArrays(null, new INDArray[] {labelMask}); graph.setLayerMaskArrays(null, new INDArray[] {labelMask});
@ -222,7 +222,7 @@ public class TestMasking extends BaseDL4JTest {
graph.setLabels(newLabels); graph.setLabels(newLabels);
graph.computeGradientAndScore(); graph.computeGradientAndScore();
double gScore2 = graph.score(); double gScore2 = graph.getScore();
INDArray gGrad2 = graph.gradient().gradient(); INDArray gGrad2 = graph.gradient().gradient();
assertEquals(gScore1, gScore2, 1e-6); assertEquals(gScore1, gScore2, 1e-6);

View File

@ -53,12 +53,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray initParams = net.params().dup(); INDArray initParams = net.getModelParams().dup();
Map<String, INDArray> initParams2 = net.getParamTable(); Map<String, INDArray> initParams2 = net.getParamTable();
net.setParams(net.params()); net.setParams(net.getModelParams());
INDArray initParamsAfter = net.params(); INDArray initParamsAfter = net.getModelParams();
Map<String, INDArray> initParams2After = net.getParamTable(); Map<String, INDArray> initParams2After = net.getParamTable();
for (String s : initParams2.keySet()) { for (String s : initParams2.keySet()) {
@ -71,7 +71,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
net.setParams(randomParams.dup()); net.setParams(randomParams.dup());
assertEquals(net.params(), randomParams); assertEquals(net.getModelParams(), randomParams);
} }
@Test @Test
@ -90,12 +90,12 @@ public class TestSetGetParameters extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray initParams = net.params().dup(); INDArray initParams = net.getModelParams().dup();
Map<String, INDArray> initParams2 = net.getParamTable(); Map<String, INDArray> initParams2 = net.getParamTable();
net.setParams(net.params()); net.setParams(net.getModelParams());
INDArray initParamsAfter = net.params(); INDArray initParamsAfter = net.getModelParams();
Map<String, INDArray> initParams2After = net.getParamTable(); Map<String, INDArray> initParams2After = net.getParamTable();
for (String s : initParams2.keySet()) { for (String s : initParams2.keySet()) {
@ -108,7 +108,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape());
net.setParams(randomParams.dup()); net.setParams(randomParams.dup());
assertEquals(net.params(), randomParams); assertEquals(net.getModelParams(), randomParams);
} }
@Test @Test
@ -128,7 +128,7 @@ public class TestSetGetParameters extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray params = net.params(); INDArray params = net.getModelParams();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
@ -137,11 +137,11 @@ public class TestSetGetParameters extends BaseDL4JTest {
MultiLayerNetwork net3 = new MultiLayerNetwork(conf); MultiLayerNetwork net3 = new MultiLayerNetwork(conf);
net3.init(params, false); net3.init(params, false);
assertEquals(params, net2.params()); assertEquals(params, net2.getModelParams());
assertEquals(params, net3.params()); assertEquals(params, net3.getModelParams());
assertNotSame(params, net2.params()); //Different objects due to clone assertNotSame(params, net2.getModelParams()); //Different objects due to clone
assertSame(params, net3.params()); //Same object due to clone assertSame(params, net3.getModelParams()); //Same object due to clone
Map<String, INDArray> paramsMap = net.getParamTable(); Map<String, INDArray> paramsMap = net.getParamTable();
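
The set/get round trip exercised above reduces to the sketch below; everything except the renamed getModelParams() accessor is either taken from the surrounding test or is an assumed helper name.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class SetGetParamsSketch {
    // Sketch: parameters set from a flat vector should be reflected by the renamed getter.
    static void roundTripParams(MultiLayerNetwork net) {
        INDArray current = net.getModelParams();                          // was net.params()
        INDArray random  = Nd4j.rand(current.dataType(), current.shape());
        net.setParams(random.dup());
        if (!net.getModelParams().equals(random)) {
            throw new AssertionError("parameters were not applied");
        }
    }
}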

View File

@ -103,14 +103,14 @@ public class TestVariableLengthTS extends BaseDL4JTest {
net.setInput(in1); net.setInput(in1);
net.setLabels(labels1); net.setLabels(labels1);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score1 = net.score(); double score1 = net.getScore();
Gradient g1 = net.gradient(); Gradient g1 = net.gradient();
net.setInput(in2); net.setInput(in2);
net.setLabels(labels2); net.setLabels(labels2);
net.setLayerMaskArrays(null, labelMask); net.setLayerMaskArrays(null, labelMask);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2 = net.score(); double score2 = net.getScore();
Gradient g2 = net.gradient(); Gradient g2 = net.gradient();
//Scores and gradients should be identical for two cases (given mask array) //Scores and gradients should be identical for two cases (given mask array)
@ -134,7 +134,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
} }
net.setLabels(labels2); net.setLabels(labels2);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2a = net.score(); double score2a = net.getScore();
Gradient g2a = net.gradient(); Gradient g2a = net.gradient();
assertEquals(score2, score2a, 1e-6); assertEquals(score2, score2a, 1e-6);
for (String s : g2map.keySet()) { for (String s : g2map.keySet()) {
@ -196,7 +196,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
net.setInput(in1); net.setInput(in1);
net.setLabels(labels1); net.setLabels(labels1);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score1 = net.score(); double score1 = net.getScore();
Gradient g1 = net.gradient(); Gradient g1 = net.gradient();
Map<String, INDArray> map1 = g1.gradientForVariable(); Map<String, INDArray> map1 = g1.gradientForVariable();
for (String s : map1.keySet()) { for (String s : map1.keySet()) {
@ -207,7 +207,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
net.setLabels(labels2); net.setLabels(labels2);
net.setLayerMaskArrays(inputMask, null); net.setLayerMaskArrays(inputMask, null);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2 = net.score(); double score2 = net.getScore();
Gradient g2 = net.gradient(); Gradient g2 = net.gradient();
net.setInput(in2); net.setInput(in2);
@ -240,7 +240,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
net.setInput(in2); net.setInput(in2);
net.setLayerMaskArrays(inputMask, null); net.setLayerMaskArrays(inputMask, null);
net.computeGradientAndScore(); net.computeGradientAndScore();
double score2a = net.score(); double score2a = net.getScore();
Gradient g2a = net.gradient(); Gradient g2a = net.gradient();
assertEquals(score2, score2a, 1e-12); assertEquals(score2, score2a, 1e-12);
for (String s : g2.gradientForVariable().keySet()) { for (String s : g2.gradientForVariable().keySet()) {
@ -327,7 +327,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
mln.setLabels(labels); mln.setLabels(labels);
mln.computeGradientAndScore(); mln.computeGradientAndScore();
double score = mln.score(); double score = mln.getScore();
assertEquals(expScore, score, 0.1, msg); assertEquals(expScore, score, 0.1, msg);
} }
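
The masked-score comparisons in this file all follow one pattern: set input, labels and mask arrays, recompute, then read the renamed score accessor. A sketch (mask arrays are assumed to be prepared by the caller; either may be null):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class MaskedScoreSketch {
    // Sketch: score a network under feature/label masks.
    static double maskedScore(MultiLayerNetwork net, INDArray features, INDArray labels,
                              INDArray featureMask, INDArray labelMask) {
        net.setInput(features);
        net.setLabels(labels);
        net.setLayerMaskArrays(featureMask, labelMask);
        net.computeGradientAndScore();
        return net.getScore();                     // was net.score()
    }
}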

View File

@ -77,7 +77,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
MultiLayerNetwork net2GradUpd = new MultiLayerNetwork(conf.clone()); MultiLayerNetwork net2GradUpd = new MultiLayerNetwork(conf.clone());
net2GradUpd.init(); net2GradUpd.init();
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
INDArray f = Nd4j.rand(minibatch, nIn); INDArray f = Nd4j.rand(minibatch, nIn);
INDArray l = Nd4j.create(minibatch, nOut); INDArray l = Nd4j.create(minibatch, nOut);
@ -109,17 +109,17 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
//Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
// on the original network // on the original network
net2GradUpd.params().subi(g.gradient()); net2GradUpd.getModelParams().subi(g.gradient());
net1GradCalc.fit(f, l); net1GradCalc.fit(f, l);
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
//============================= //=============================
if (!(u instanceof Sgd)) { if (!(u instanceof Sgd)) {
net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray()); net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
} }
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
assertEquals(net1GradCalc.getUpdater().getStateViewArray(), assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
net2GradUpd.getUpdater().getStateViewArray()); net2GradUpd.getUpdater().getStateViewArray());
@ -130,7 +130,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
net1GradCalc.fit(f, l); net1GradCalc.fit(f, l);
net2GradUpd.fit(f, l); net2GradUpd.fit(f, l);
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
} }
} }
} }
@ -169,7 +169,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
ComputationGraph net2GradUpd = new ComputationGraph(conf.clone()); ComputationGraph net2GradUpd = new ComputationGraph(conf.clone());
net2GradUpd.init(); net2GradUpd.init();
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
INDArray f = Nd4j.rand(minibatch, nIn); INDArray f = Nd4j.rand(minibatch, nIn);
INDArray l = Nd4j.create(minibatch, nOut); INDArray l = Nd4j.create(minibatch, nOut);
@ -201,16 +201,16 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
//Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op
// on the original network // on the original network
net2GradUpd.params().subi(g.gradient()); net2GradUpd.getModelParams().subi(g.gradient());
net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l}); net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
//============================= //=============================
if (!(u instanceof Sgd)) { if (!(u instanceof Sgd)) {
net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray()); net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray());
} }
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
assertEquals(net1GradCalc.getUpdater().getStateViewArray(), assertEquals(net1GradCalc.getUpdater().getStateViewArray(),
net2GradUpd.getUpdater().getStateViewArray()); net2GradUpd.getUpdater().getStateViewArray());
@ -222,7 +222,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l}); net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l});
net2GradUpd.fit(new INDArray[] {f}, new INDArray[] {l}); net2GradUpd.fit(new INDArray[] {f}, new INDArray[] {l});
assertEquals(net1GradCalc.params(), net2GradUpd.params()); assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams());
} }
} }
} }
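
The manual gradient application in both tests above relies on getModelParams() returning the live flattened view, so an in-place subi applies the update directly; a sketch (the Gradient is assumed to come from computeGradientAndScore on a network with identical parameters):

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class ApplyGradientSketch {
    // Sketch: apply an externally computed gradient by mutating the parameter view in place.
    static void applyGradient(MultiLayerNetwork net, Gradient g) {
        net.getModelParams().subi(g.gradient());   // was net.params().subi(...); subi mutates in place
    }
}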

View File

@ -25,7 +25,6 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
@ -94,7 +93,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
ComputationGraph modelToFineTune = new ComputationGraph(expectedConf); ComputationGraph modelToFineTune = new ComputationGraph(expectedConf);
modelToFineTune.init(); modelToFineTune.init();
modelToFineTune.setParams(expectedModel.params()); modelToFineTune.setParams(expectedModel.getModelParams());
//model after applying changes with transfer learning //model after applying changes with transfer learning
ComputationGraph modelNow = ComputationGraph modelNow =
new TransferLearning.GraphBuilder(modelToFineTune) new TransferLearning.GraphBuilder(modelToFineTune)
@ -108,8 +107,8 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
//Check params after fit //Check params after fit
modelNow.fit(randomData); modelNow.fit(randomData);
expectedModel.fit(randomData); expectedModel.fit(randomData);
assertEquals(modelNow.score(), expectedModel.score(), 1e-8); assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-8);
assertEquals(modelNow.params(), expectedModel.params()); assertEquals(modelNow.getModelParams(), expectedModel.getModelParams());
} }
@Test @Test
@ -139,9 +138,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
//.setOutputs("layer3") //.setOutputs("layer3")
.build(); .build();
BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").getLayerConfiguration()); BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration());
BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").getLayerConfiguration()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration());
BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").getLayerConfiguration()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration());
assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1)));
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier());
@ -161,22 +160,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
modelExpectedArch.init(); modelExpectedArch.init();
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
modelNow.getLayer("layer0").params().shape()); modelNow.getLayer("layer0").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
modelNow.getLayer("layer1").params().shape()); modelNow.getLayer("layer1").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
modelNow.getLayer("layer2").params().shape()); modelNow.getLayer("layer2").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
modelNow.getLayer("layer3").params().shape()); modelNow.getLayer("layer3").getParams().shape());
modelNow.setParams(modelExpectedArch.params()); modelNow.setParams(modelExpectedArch.getModelParams());
//fit should give the same results //fit should give the same results
modelExpectedArch.fit(randomData); modelExpectedArch.fit(randomData);
modelNow.fit(randomData); modelNow.fit(randomData);
assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
assertEquals(modelExpectedArch.params(), modelNow.params()); assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
} }
@Test @Test
@ -227,22 +226,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
modelExpectedArch.init(); modelExpectedArch.init();
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(),
modelNow.getLayer("layer0").params().shape()); modelNow.getLayer("layer0").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(),
modelNow.getLayer("layer1").params().shape()); modelNow.getLayer("layer1").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(),
modelNow.getLayer("layer2").params().shape()); modelNow.getLayer("layer2").getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(),
modelNow.getLayer("layer3").params().shape()); modelNow.getLayer("layer3").getParams().shape());
modelNow.setParams(modelExpectedArch.params()); modelNow.setParams(modelExpectedArch.getModelParams());
//fit should give the same results //fit should give the same results
modelExpectedArch.fit(randomData); modelExpectedArch.fit(randomData);
modelNow.fit(randomData); modelNow.fit(randomData);
assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8);
assertEquals(modelExpectedArch.params(), modelNow.params()); assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
} }
@Test @Test
@ -385,14 +384,14 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());
modelNow.setParams(modelExpectedArch.params()); modelNow.setParams(modelExpectedArch.getModelParams());
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
modelExpectedArch.fit(randomData); modelExpectedArch.fit(randomData);
modelNow.fit(randomData); modelNow.fit(randomData);
i++; i++;
} }
assertEquals(modelExpectedArch.params(), modelNow.params()); assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
} }

View File

@ -26,10 +26,9 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
@ -99,7 +98,7 @@ public class TransferLearningComplex extends BaseDL4JTest {
} }
//Also check config: //Also check config:
BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
assertEquals(new Adam(2e-2), bl.getIUpdater()); assertEquals(new Adam(2e-2), bl.getIUpdater());
assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn()); assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn());
} }
@ -154,8 +153,8 @@ public class TransferLearningComplex extends BaseDL4JTest {
.setOutputs("outRight").build(); .setOutputs("outRight").build();
ComputationGraph modelOther = new ComputationGraph(otherConf); ComputationGraph modelOther = new ComputationGraph(otherConf);
modelOther.init(); modelOther.init();
modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params()); modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").getParams());
modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params()); modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").getParams());
modelToTune.getVertex("denseCentre0").setLayerAsFrozen(); modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
ComputationGraph modelNow = ComputationGraph modelNow =
@ -179,11 +178,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
assertEquals(otherRandData.getFeatures(0), assertEquals(otherRandData.getFeatures(0),
modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0")); modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params()); assertEquals(modelOther.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());
assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params()); assertEquals(modelOther.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());
assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params()); assertEquals(modelOther.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());
assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); assertEquals(modelOther.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
n++; n++;
} }
@ -237,11 +236,11 @@ public class TransferLearningComplex extends BaseDL4JTest {
assertEquals(otherRandData.getFeatures(0), assertEquals(otherRandData.getFeatures(0),
modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0")); modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));
assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params()); assertEquals(modelToTune.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams());
assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params()); assertEquals(modelToTune.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams());
assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params()); assertEquals(modelToTune.getLayer("outCentre").getParams(), modelNow.getLayer("outCentre").getParams());
n++; n++;
} }

View File

@ -178,25 +178,25 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2"); TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
MultiDataSet featurizedDataSet = helper.featurize(origData); MultiDataSet featurizedDataSet = helper.featurize(origData);
assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params()); assertEquals(modelIdentical.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams());
modelIdentical.fit(origData); modelIdentical.fit(origData);
helper.fitFeaturized(featurizedDataSet); helper.fitFeaturized(featurizedDataSet);
assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params()); assertEquals(modelIdentical.getLayer("denseCentre0").getParams(), modelToTune.getLayer("denseCentre0").getParams());
assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params()); assertEquals(modelIdentical.getLayer("denseCentre1").getParams(), modelToTune.getLayer("denseCentre1").getParams());
assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params()); assertEquals(modelIdentical.getLayer("denseCentre2").getParams(), modelToTune.getLayer("denseCentre2").getParams());
assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params()); assertEquals(modelIdentical.getLayer("denseCentre3").getParams(), modelToTune.getLayer("denseCentre3").getParams());
assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params()); assertEquals(modelIdentical.getLayer("outCentre").getParams(), modelToTune.getLayer("outCentre").getParams());
assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(), assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(),
modelToTune.getLayer("denseRight").getNetConfiguration().toJson()); modelToTune.getLayer("denseRight").getNetConfiguration().toJson());
assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params()); assertEquals(modelIdentical.getLayer("denseRight").getParams(), modelToTune.getLayer("denseRight").getParams());
assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(), assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(),
modelToTune.getLayer("denseRight0").getNetConfiguration().toJson()); modelToTune.getLayer("denseRight0").getNetConfiguration().toJson());
//assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params()); //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params());
assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params()); assertEquals(modelIdentical.getLayer("denseRight1").getParams(), modelToTune.getLayer("denseRight1").getParams());
assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); assertEquals(modelIdentical.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams());
assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params()); assertEquals(modelIdentical.getLayer("denseLeft0").getParams(), modelToTune.getLayer("denseLeft0").getParams());
assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params()); assertEquals(modelIdentical.getLayer("outLeft").getParams(), modelToTune.getLayer("outLeft").getParams());
// log.info(modelIdentical.summary()); // log.info(modelIdentical.summary());
// log.info(helper.unfrozenGraph().summary()); // log.info(helper.unfrozenGraph().summary());
@ -230,7 +230,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1); TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1);
INDArray paramsLastTwoLayers = INDArray paramsLastTwoLayers =
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
MultiLayerNetwork notFrozen = new MultiLayerNetwork( MultiLayerNetwork notFrozen = new MultiLayerNetwork(
(NeuralNetConfiguration) overallConf.clone().list() (NeuralNetConfiguration) overallConf.clone().list()
.layer(0, new Builder().nIn(2).nOut(3).build()) .layer(0, new Builder().nIn(2).nOut(3).build())
@ -248,9 +248,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
modelNow.fit(randomData); modelNow.fit(randomData);
} }
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(),
notFrozen.params()); notFrozen.getModelParams());
INDArray act = modelNow.params(); INDArray act = modelNow.getModelParams();
assertEquals(expected, act); assertEquals(expected, act);
} }
} }
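
The frozen-layer check above rebuilds the full parameter vector by concatenating per-layer views, now obtained through getParams() on each layer; a sketch with illustrative layer indices:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class StackParamsSketch {
    // Sketch: concatenate per-layer parameter views into one flat vector for comparison.
    static INDArray stackLayerParams(MultiLayerNetwork net, int... layers) {
        INDArray[] parts = new INDArray[layers.length];
        for (int i = 0; i < layers.length; i++) {
            parts[i] = net.getLayer(layers[i]).getParams();   // was .params() on each layer
        }
        return Nd4j.hstack(parts);
    }
}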

View File

@ -91,7 +91,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.build(); .build();
for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) { for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) {
BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration());
assertEquals(new RmsProp(0.5), bl.getIUpdater()); assertEquals(new RmsProp(0.5), bl.getIUpdater());
} }
@ -107,9 +107,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.build()) .build())
.build()); .build());
expectedModel.init(); expectedModel.init();
expectedModel.setParams(modelToFineTune.params().dup()); expectedModel.setParams(modelToFineTune.getModelParams().dup());
assertEquals(expectedModel.params(), modelNow.params()); assertEquals(expectedModel.getModelParams(), modelNow.getModelParams());
//Check json //Check json
NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration(); NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration();
@ -119,9 +119,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
modelNow.fit(randomData); modelNow.fit(randomData);
expectedModel.fit(randomData); expectedModel.fit(randomData);
assertEquals(modelNow.score(), expectedModel.score(), 1e-6); assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-6);
INDArray pExp = expectedModel.params(); INDArray pExp = expectedModel.getModelParams();
INDArray pNow = modelNow.params(); INDArray pNow = modelNow.getModelParams();
assertEquals(pExp, pNow); assertEquals(pExp, pNow);
} }
@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
//Will fail - expected because of dist and weight init changes //Will fail - expected because of dist and weight init changes
//assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer()); BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer());
BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer());
BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer());
assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
try { try {
assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
@ -173,18 +173,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals(bl3.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl3.getWeightInitFn(), new WeightInitXavier());
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
modelNow.setParams(modelExpectedArch.params()); modelNow.setParams(modelExpectedArch.getModelParams());
//fit should give the same results //fit should give the same results
modelExpectedArch.fit(randomData); modelExpectedArch.fit(randomData);
modelNow.fit(randomData); modelNow.fit(randomData);
assertEquals(modelExpectedArch.score(), modelNow.score(), 0.000001); assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 0.000001);
assertEquals(modelExpectedArch.params(), modelNow.params()); assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
} }
@ -227,20 +227,20 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
modelExpectedArch.init(); modelExpectedArch.init();
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
modelNow.setParams(modelExpectedArch.params()); modelNow.setParams(modelExpectedArch.getModelParams());
//fit should give the same results //fit should give the same results
modelExpectedArch.fit(randomData); modelExpectedArch.fit(randomData);
modelNow.fit(randomData); modelNow.fit(randomData);
double scoreExpected = modelExpectedArch.score(); double scoreExpected = modelExpectedArch.getScore();
double scoreActual = modelNow.score(); double scoreActual = modelNow.getScore();
assertEquals(scoreExpected, scoreActual, 1e-4); assertEquals(scoreExpected, scoreActual, 1e-4);
assertEquals(modelExpectedArch.params(), modelNow.params()); assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams());
} }
@Test @Test
@ -370,14 +370,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(),
modelNow.getNetConfiguration().getConf(5).toJson()); modelNow.getNetConfiguration().getConf(5).toJson());
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
//subsampling has no params //subsampling has no params
//assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(4).params().shape(), modelNow.getLayer(4).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(4).getParams().shape(), modelNow.getLayer(4).getParams().shape());
assertArrayEquals(modelExpectedArch.getLayer(5).params().shape(), modelNow.getLayer(5).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(5).getParams().shape(), modelNow.getLayer(5).getParams().shape());
} }
@ -449,23 +449,23 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.inputType(InputType.convolutionalFlat(12, 12, 20)).build()); .inputType(InputType.convolutionalFlat(12, 12, 20)).build());
notFrozen.init(); notFrozen.init();
assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
//subsampling has no params //subsampling has no params
//assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
//subsampling has no params //subsampling has no params
//assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); //assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape());
assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape()); assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params()); modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());
assertArrayEquals(notFrozen.getLayer(4).params().shape(), modelNow.getLayer(6).params().shape()); assertArrayEquals(notFrozen.getLayer(4).getParams().shape(), modelNow.getLayer(6).getParams().shape());
modelNow.getLayer(6).setParams(notFrozen.getLayer(4).params()); modelNow.getLayer(6).setParams(notFrozen.getLayer(4).getParams());
assertArrayEquals(notFrozen.getLayer(5).params().shape(), modelNow.getLayer(7).params().shape()); assertArrayEquals(notFrozen.getLayer(5).getParams().shape(), modelNow.getLayer(7).getParams().shape());
modelNow.getLayer(7).setParams(notFrozen.getLayer(5).params()); modelNow.getLayer(7).setParams(notFrozen.getLayer(5).getParams());
assertArrayEquals(notFrozen.getLayer(6).params().shape(), modelNow.getLayer(8).params().shape()); assertArrayEquals(notFrozen.getLayer(6).getParams().shape(), modelNow.getLayer(8).getParams().shape());
modelNow.getLayer(8).setParams(notFrozen.getLayer(6).params()); modelNow.getLayer(8).setParams(notFrozen.getLayer(6).getParams());
int i = 0; int i = 0;
while (i < 3) { while (i < 3) {
@ -474,8 +474,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.getModelParams());
} }
@ -503,13 +503,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
//Check original net isn't modified: //Check original net isn't modified:
BaseLayer l0 = (BaseLayer) net.getLayer(0).getLayerConfiguration(); BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(new Adam(1e-4), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
BaseLayer l1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(new Adam(1e-4), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@ -518,13 +518,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals(BackpropType.Standard, conf.getBackpropType()); assertEquals(BackpropType.Standard, conf.getBackpropType());
//Check new net has only the appropriate things modified (i.e., LR) //Check new net has only the appropriate things modified (i.e., LR)
l0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration();
assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(new Adam(2e-2), l0.getIUpdater());
assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn());
assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn());
assertEquals(0.1, TestUtils.getL1(l0), 1e-6); assertEquals(0.1, TestUtils.getL1(l0), 1e-6);
l1 = (BaseLayer) net2.getLayer(1).getLayerConfiguration(); l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration();
assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(new Adam(2e-2), l1.getIUpdater());
assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn());
assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
@ -586,17 +586,17 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.build()); .build());
notFrozen.init(); notFrozen.init();
assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape());
//subsampling has no params //subsampling has no params
//assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape());
assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape());
modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams());
assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(notFrozen.getLayer(1).getParams().shape(), modelNow.getLayer(3).getParams().shape());
modelNow.getLayer(3).setParams(notFrozen.getLayer(1).params()); modelNow.getLayer(3).setParams(notFrozen.getLayer(1).getParams());
assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape()); assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape());
modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams());
assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape());
modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params()); modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams());
int i = 0; int i = 0;
while (i < 3) { while (i < 3) {
@ -605,8 +605,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.getModelParams());
} }
@Test @Test
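The hunks above apply one mechanical rename: Layer.params() becomes getParams(), the network-level params() becomes getModelParams(), and score() becomes getScore(). The following is a minimal sketch of the resulting comparison pattern, assuming only the renamed accessors visible in this diff; the helper name assertSameShapeAndScore is hypothetical and not part of the test class.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

// Hypothetical helper: compares two networks that should be architecturally identical.
static void assertSameShapeAndScore(MultiLayerNetwork expected, MultiLayerNetwork actual) {
    // Whole-model parameter vector: getModelParams() replaces the old params()
    assertArrayEquals(expected.getModelParams().shape(), actual.getModelParams().shape());
    // Per-layer parameters: getParams() replaces the old Layer.params()
    for (int i = 0; i < expected.getLayers().length; i++) {
        INDArray e = expected.getLayer(i).getParams();
        INDArray a = actual.getLayer(i).getParams();
        if (e != null && a != null) { // layers without parameters (e.g. subsampling) are skipped
            assertArrayEquals(e.shape(), a.shape());
        }
    }
    // Score accessor: getScore() replaces the old score()
    assertEquals(expected.getScore(), actual.getScore(), 1e-6);
}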

View File

@ -99,7 +99,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -144,7 +144,7 @@ public class TestUpdaters extends BaseDL4JTest {
msdx.put(key, msdxTmp); msdx.put(key, msdxTmp);
count++; count++;
} }
assertEquals(rho, ((AdaDelta)layer.layerConf().getIUpdater()).getRho(), 1e-4); assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4);
} }
assertEquals(4, count); assertEquals(4, count);
@ -165,7 +165,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -185,7 +185,7 @@ public class TestUpdaters extends BaseDL4JTest {
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
count++; count++;
} }
assertEquals(lr, ((AdaGrad)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
assertEquals(2, count); assertEquals(2, count);
} }
@ -209,7 +209,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -245,8 +245,8 @@ public class TestUpdaters extends BaseDL4JTest {
count++; count++;
} }
assertEquals(beta1, ((Adam)layer.layerConf().getIUpdater()).getBeta1(), 1e-4); assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
assertEquals(beta2, ((Adam)layer.layerConf().getIUpdater()).getBeta2(), 1e-4); assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
assertEquals(2, count); assertEquals(2, count);
} }
@ -273,7 +273,7 @@ public class TestUpdaters extends BaseDL4JTest {
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -362,7 +362,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -398,8 +398,8 @@ public class TestUpdaters extends BaseDL4JTest {
count++; count++;
} }
assertEquals(beta1, ((AdaMax)layer.layerConf().getIUpdater()).getBeta1(), 1e-4); assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4);
assertEquals(beta2, ((AdaMax)layer.layerConf().getIUpdater()).getBeta2(), 1e-4); assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4);
assertEquals(2, count); assertEquals(2, count);
} }
@ -418,7 +418,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -443,7 +443,7 @@ public class TestUpdaters extends BaseDL4JTest {
count++; count++;
} }
assertEquals(mu, ((Nesterovs)layer.layerConf().getIUpdater()).getMomentum(), 1e-4); assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4);
assertEquals(2, count); assertEquals(2, count);
} }
@ -465,7 +465,7 @@ public class TestUpdaters extends BaseDL4JTest {
BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams);
INDArray updaterState = Nd4j.create(1, updaterStateSize); INDArray updaterState = Nd4j.create(1, updaterStateSize);
updater.setStateViewArray(layer, updaterState, true); updater.setStateViewArray(layer, updaterState, true);
@ -495,7 +495,7 @@ public class TestUpdaters extends BaseDL4JTest {
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
lastG.put(key, lastGTmp); lastG.put(key, lastGTmp);
} }
assertEquals(rmsDecay, ((RmsProp)layer.layerConf().getIUpdater()).getRmsDecay(), 1e-4); assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4);
} }
@Test @Test
@ -527,7 +527,7 @@ public class TestUpdaters extends BaseDL4JTest {
gradExpected = val.mul(lr); gradExpected = val.mul(lr);
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
} }
assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
} }
@ -769,7 +769,7 @@ public class TestUpdaters extends BaseDL4JTest {
gradExpected = val.mul(lr); gradExpected = val.mul(lr);
assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
} }
assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
//Test with pretrain == false //Test with pretrain == false
@ -797,7 +797,7 @@ public class TestUpdaters extends BaseDL4JTest {
layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(gradients); layer.setBackpropGradientsViewArray(gradients);
updater = UpdaterCreator.getUpdater(layer); updater = UpdaterCreator.getUpdater(layer);
assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4);
} }
@Test @Test
@ -858,11 +858,11 @@ public class TestUpdaters extends BaseDL4JTest {
//Check first updater block: //Check first updater block:
UpdaterBlock ub0 = blocks.get(0); UpdaterBlock ub0 = blocks.get(0);
assertEquals(3, ub0.getLayersAndVariablesInBlock().size()); assertEquals(3, ub0.getLayersAndVariablesInBlock().size());
assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(0).getParamName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(0).getParamName());
assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.BIAS_KEY, ub0.getLayersAndVariablesInBlock().get(1).getParamName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub0.getLayersAndVariablesInBlock().get(1).getParamName());
assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getConfig().getLayerName()); assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(2).getParamName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(2).getParamName());
int nParams0 = 10 * 10 + 10 + 10 * 10; int nParams0 = 10 * 10 + 10 + 10 * 10;
@ -875,7 +875,7 @@ public class TestUpdaters extends BaseDL4JTest {
//Check second updater block: //Check second updater block:
UpdaterBlock ub1 = blocks.get(1); UpdaterBlock ub1 = blocks.get(1);
assertEquals(1, ub1.getLayersAndVariablesInBlock().size()); assertEquals(1, ub1.getLayersAndVariablesInBlock().size());
assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.BIAS_KEY, ub1.getLayersAndVariablesInBlock().get(0).getParamName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub1.getLayersAndVariablesInBlock().get(0).getParamName());
int nParams1 = 10; int nParams1 = 10;
@ -888,9 +888,9 @@ public class TestUpdaters extends BaseDL4JTest {
//Check third updater block: //Check third updater block:
UpdaterBlock ub2 = blocks.get(2); UpdaterBlock ub2 = blocks.get(2);
assertEquals(2, ub2.getLayersAndVariablesInBlock().size()); assertEquals(2, ub2.getLayersAndVariablesInBlock().size());
assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub2.getLayersAndVariablesInBlock().get(0).getParamName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub2.getLayersAndVariablesInBlock().get(0).getParamName());
assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.BIAS_KEY, ub2.getLayersAndVariablesInBlock().get(1).getParamName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub2.getLayersAndVariablesInBlock().get(1).getParamName());
int nParams2 = 10 * 10 + 10; int nParams2 = 10 * 10 + 10;
@ -903,9 +903,9 @@ public class TestUpdaters extends BaseDL4JTest {
//Check fourth updater block: //Check fourth updater block:
UpdaterBlock ub3 = blocks.get(3); UpdaterBlock ub3 = blocks.get(3);
assertEquals(2, ub3.getLayersAndVariablesInBlock().size()); assertEquals(2, ub3.getLayersAndVariablesInBlock().size());
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName());
assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName());
int nParams3 = 10 * 10 + 10; int nParams3 = 10 * 10 + 10;
@ -918,9 +918,9 @@ public class TestUpdaters extends BaseDL4JTest {
//Check fifth updater block //Check fifth updater block
UpdaterBlock ub4 = blocks.get(4); UpdaterBlock ub4 = blocks.get(4);
assertEquals(2, ub4.getLayersAndVariablesInBlock().size()); assertEquals(2, ub4.getLayersAndVariablesInBlock().size());
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName());
assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName());
assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName());
int nParams4 = 10 * 10 + 10; int nParams4 = 10 * 10 + 10;
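Throughout TestUpdaters the old layer.layerConf() shortcut becomes layer.getTypedLayerConfiguration(), and getConfig() becomes getTrainingConfig(). Below is a hedged fragment (intended for a test method body) showing how an updater hyperparameter is read back through the renamed accessor; the construction mirrors the instantiate(...) call in the hunks above, while the 10-in/5-out layer size and its parameter count are illustrative values.

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

// Build a single dense layer the same way the tests do, then inspect its updater config.
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .updater(new Adam(1e-3))
        .layer(new DenseLayer.Builder().nIn(10).nOut(5).build())
        .build();
long numParams = 10 * 5 + 5;                        // weights + biases of the 10->5 dense layer
INDArray params = Nd4j.create(1, numParams);
BaseLayer layer = (BaseLayer) conf.getFirstLayer()
        .instantiate(conf, null, 0, params, true, params.dataType());

// layerConf() is gone; the typed layer configuration now exposes the IUpdater.
Adam adam = (Adam) layer.getTypedLayerConfiguration().getIUpdater();
double beta1 = adam.getBeta1();                     // hyperparameter read back from the config
int stateSize = (int) adam.stateSize(numParams);    // updater state size for this parameter count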

View File

@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater.custom;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest {
.build(); .build();
//First: Check updater config //First: Check updater config
assertTrue(((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater); assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater);
assertTrue(((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater); assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater);
assertTrue(((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd); assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd);
assertTrue(((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd); assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd);
CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater(); CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater();
CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater(); CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater();
assertEquals(lr, u0_0.getLearningRate(), 1e-6); assertEquals(lr, u0_0.getLearningRate(), 1e-6);
assertEquals(lr, u0_1.getLearningRate(), 1e-6); assertEquals(lr, u0_1.getLearningRate(), 1e-6);
Sgd u1_0 = (Sgd) ((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater(); Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater();
Sgd u1_1 = (Sgd) ((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater(); Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater();
assertEquals(lr, u1_0.getLearningRate(), 1e-6); assertEquals(lr, u1_0.getLearningRate(), 1e-6);
assertEquals(lr, u1_1.getLearningRate(), 1e-6); assertEquals(lr, u1_1.getLearningRate(), 1e-6);

View File

@ -81,7 +81,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer()); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer());
double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
assertEquals(1.0, step, 1e-3); assertEquals(1.0, step, 1e-3);
} }
@ -97,11 +97,11 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels()); layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.score(); score1 = layer.getScore();
BackTrackLineSearch lineSearch = BackTrackLineSearch lineSearch =
new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer()); new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer());
double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
assertEquals(1.0, step, 1e-3); assertEquals(1.0, step, 1e-3);
} }
@ -118,18 +118,18 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels()); layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.score(); score1 = layer.getScore();
INDArray origGradient = layer.gradient().gradient().dup(); INDArray origGradient = layer.gradient().gradient().dup();
NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction(); NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction();
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable());
INDArray currParams = layer.params(); INDArray currParams = layer.getModelParams();
sf.step(currParams, origGradient, step); sf.step(currParams, origGradient, step);
layer.setParamsTable(currParams); layer.setParamsTable(currParams);
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score2 = layer.score(); score2 = layer.getScore();
assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2); assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2);
@ -146,19 +146,19 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.setLabels(irisData.getLabels()); layer.setLabels(irisData.getLabels());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score1 = layer.score(); score1 = layer.getScore();
INDArray origGradient = layer.gradient().gradient().dup(); INDArray origGradient = layer.gradient().gradient().dup();
DefaultStepFunction sf = new DefaultStepFunction(); DefaultStepFunction sf = new DefaultStepFunction();
BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(), double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(),
layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable()); layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable());
INDArray currParams = layer.params(); INDArray currParams = layer.getModelParams();
sf.step(currParams, origGradient, step); sf.step(currParams, origGradient, step);
layer.setParamsTable(currParams); layer.setParamsTable(currParams);
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
score2 = layer.score(); score2 = layer.getScore();
assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2); assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2);
} }
@ -190,12 +190,12 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer)); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
network.init(); network.init();
TrainingListener listener = new ScoreIterationListener(10); TrainingListener listener = new ScoreIterationListener(10);
network.setListeners(Collections.singletonList(listener)); network.addTrainingListeners(Collections.singletonList(listener));
double oldScore = network.score(data); double oldScore = network.score(data);
for( int i=0; i<100; i++ ) { for( int i=0; i<100; i++ ) {
network.fit(data.getFeatures(), data.getLabels()); network.fit(data.getFeatures(), data.getLabels());
} }
double score = network.score(); double score = network.getScore();
assertTrue(score < oldScore); assertTrue(score < oldScore);
} }
@ -208,13 +208,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
network.init(); network.init();
TrainingListener listener = new ScoreIterationListener(10); TrainingListener listener = new ScoreIterationListener(10);
network.setListeners(Collections.singletonList(listener)); network.addTrainingListeners(Collections.singletonList(listener));
double firstScore = network.score(data); double firstScore = network.score(data);
for( int i=0; i<5; i++ ) { for( int i=0; i<5; i++ ) {
network.fit(data.getFeatures(), data.getLabels()); network.fit(data.getFeatures(), data.getLabels());
} }
double score = network.score(); double score = network.getScore();
assertTrue(score < firstScore); assertTrue(score < firstScore);
} }
@ -227,13 +227,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
network.init(); network.init();
TrainingListener listener = new ScoreIterationListener(10); TrainingListener listener = new ScoreIterationListener(10);
network.setListeners(Collections.singletonList(listener)); network.addTrainingListeners(Collections.singletonList(listener));
double oldScore = network.score(data); double oldScore = network.score(data);
for( int i=0; i<5; i++ ) { for( int i=0; i<5; i++ ) {
network.fit(data.getFeatures(), data.getLabels()); network.fit(data.getFeatures(), data.getLabels());
} }
double score = network.score(); double score = network.getScore();
assertTrue(score < oldScore); assertTrue(score < oldScore);
} }
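In BackTrackLineSearchTest every setListeners call becomes addTrainingListeners, the no-argument score() becomes getScore(), and score(DataSet) keeps its name because it still evaluates on the data you pass in. The fragment below restates that training-loop shape using only calls that appear in these hunks; getIrisMultiLayerConfig is the test class's own helper, and the optimization algorithm value is an illustrative choice.

import java.util.Collections;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import static org.junit.jupiter.api.Assertions.assertTrue;

DataSet data = new IrisDataSetIterator(150, 150).next();
MultiLayerNetwork network = new MultiLayerNetwork(
        getIrisMultiLayerConfig(Activation.SIGMOID, OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT));
network.init();

// addTrainingListeners replaces setListeners; both varargs and Collection overloads appear in the diff.
TrainingListener listener = new ScoreIterationListener(10);
network.addTrainingListeners(Collections.singletonList(listener));

double oldScore = network.score(data);   // score(DataSet) is unchanged: computes the score for this data
for (int i = 0; i < 100; i++) {
    network.fit(data.getFeatures(), data.getLabels());
}
double newScore = network.getScore();    // getScore() replaces score(): the score from the last fit
assertTrue(newScore < oldScore);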

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
@ -211,38 +212,38 @@ public class TestOptimizers extends BaseDL4JTest {
System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= " System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= "
+ nDimensions); + nDimensions);
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter) LayerConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter)
.updater(new Sgd(1e-2)) .updater(new Sgd(1e-2))
.layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build().getFlattenedLayerConfigurations().get(0);
conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
Random rng = new DefaultRandom(12345L); Random rng = new DefaultRandom(12345L);
org.nd4j.linalg.api.rng.distribution.Distribution dist = org.nd4j.linalg.api.rng.distribution.Distribution dist =
new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10); new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10);
IModel m = new SphereFunctionModel(nDimensions, dist, conf); IModel m = new SphereFunctionModel(nDimensions, dist, conf);
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double scoreBefore = m.score(); double scoreBefore = m.getScore();
assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore)); assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore));
if (PRINT_OPT_RESULTS) { if (PRINT_OPT_RESULTS) {
System.out.println("Before:"); System.out.println("Before:");
System.out.println(scoreBefore); System.out.println(scoreBefore);
System.out.println(m.params()); System.out.println(m.getModelParams());
} }
ConvexOptimizer opt = getOptimizer(oa, conf, m); ConvexOptimizer opt = getOptimizer(oa, conf.getNetConfiguration(), m);
opt.setupSearchState(m.gradientAndScore()); opt.setupSearchState(m.gradientAndScore());
for( int i=0; i<100; i++ ) { for( int i=0; i<100; i++ ) {
opt.optimize(LayerWorkspaceMgr.noWorkspaces()); opt.optimize(LayerWorkspaceMgr.noWorkspaces());
} }
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double scoreAfter = m.score(); double scoreAfter = m.getScore();
assertTrue(!Double.isNaN(scoreAfter) && !Double.isInfinite(scoreAfter)); assertTrue(!Double.isNaN(scoreAfter) && !Double.isInfinite(scoreAfter));
if (PRINT_OPT_RESULTS) { if (PRINT_OPT_RESULTS) {
System.out.println("After:"); System.out.println("After:");
System.out.println(scoreAfter); System.out.println(scoreAfter);
System.out.println(m.params()); System.out.println(m.getModelParams());
} }
//Expected behaviour after optimization: //Expected behaviour after optimization:
@ -279,17 +280,17 @@ public class TestOptimizers extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
IModel m = new SphereFunctionModel(100, dist, conf); IModel m = new SphereFunctionModel(100, dist, conf.getFlattenedLayerConfigurations().get(0));
if (i == 0) { if (i == 0) {
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[0] = m.score(); //Before optimization scores[0] = m.getScore(); //Before optimization
} else { } else {
ConvexOptimizer opt = getOptimizer(oa, conf, m); ConvexOptimizer opt = getOptimizer(oa, conf, m);
for( int j=0; j<100; j++ ) { for( int j=0; j<100; j++ ) {
opt.optimize(LayerWorkspaceMgr.noWorkspaces()); opt.optimize(LayerWorkspaceMgr.noWorkspaces());
} }
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[i] = m.score(); scores[i] = m.getScore();
assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i])); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
} }
} }
@ -316,7 +317,7 @@ public class TestOptimizers extends BaseDL4JTest {
private static final long serialVersionUID = -6963606137417355405L; private static final long serialVersionUID = -6963606137417355405L;
private SphereFunctionModel(int nParams, org.nd4j.linalg.api.rng.distribution.Distribution distribution, private SphereFunctionModel(int nParams, org.nd4j.linalg.api.rng.distribution.Distribution distribution,
NeuralNetConfiguration conf) { LayerConfiguration conf) {
super(distribution.sample(new int[] {1, nParams}), conf); super(distribution.sample(new int[] {1, nParams}), conf);
} }
@ -437,7 +438,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public void setListeners(TrainingListener... listeners) { public void addTrainingListeners(TrainingListener... listeners) {
} }
@ -499,17 +500,17 @@ public class TestOptimizers extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build();
conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
IModel m = new RastriginFunctionModel(10, conf); IModel m = new RastriginFunctionModel(10, conf.getFlattenedLayerConfigurations().get(0));
int nParams = (int)m.numParams(); int nParams = (int)m.numParams();
if (i == 0) { if (i == 0) {
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[0] = m.score(); //Before optimization scores[0] = m.getScore(); //Before optimization
} else { } else {
ConvexOptimizer opt = getOptimizer(oa, conf, m); ConvexOptimizer opt = getOptimizer(oa, conf, m);
opt.getUpdater().setStateViewArray((Layer) m, Nd4j.create(new int[] {1, nParams}, 'c'), true); opt.getUpdater().setStateViewArray((Layer) m, Nd4j.create(new int[] {1, nParams}, 'c'), true);
opt.optimize(LayerWorkspaceMgr.noWorkspaces()); opt.optimize(LayerWorkspaceMgr.noWorkspaces());
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[i] = m.score(); scores[i] = m.getScore();
assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i])); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]));
} }
} }
@ -540,7 +541,7 @@ public class TestOptimizers extends BaseDL4JTest {
private static class RastriginFunctionModel extends SimpleOptimizableModel { private static class RastriginFunctionModel extends SimpleOptimizableModel {
private static final long serialVersionUID = -1772954508787487941L; private static final long serialVersionUID = -1772954508787487941L;
private RastriginFunctionModel(int nDimensions, NeuralNetConfiguration conf) { private RastriginFunctionModel(int nDimensions, LayerConfiguration conf) {
super(initParams(nDimensions), conf); super(initParams(nDimensions), conf);
} }
@ -710,7 +711,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public void setListeners(TrainingListener... listeners) { public void addTrainingListeners(TrainingListener... listeners) {
} }
@ -768,15 +769,15 @@ public class TestOptimizers extends BaseDL4JTest {
.build(); .build();
conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here
IModel m = new RosenbrockFunctionModel(100, conf); IModel m = new RosenbrockFunctionModel(100, conf.getFlattenedLayerConfigurations().get(0));
if (i == 0) { if (i == 0) {
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[0] = m.score(); //Before optimization scores[0] = m.getScore(); //Before optimization
} else { } else {
ConvexOptimizer opt = getOptimizer(oa, conf, m); ConvexOptimizer opt = getOptimizer(oa, conf, m);
opt.optimize(LayerWorkspaceMgr.noWorkspaces()); opt.optimize(LayerWorkspaceMgr.noWorkspaces());
m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
scores[i] = m.score(); scores[i] = m.getScore();
assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]), "NaN or infinite score: " + scores[i]); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]), "NaN or infinite score: " + scores[i]);
} }
} }
@ -810,7 +811,7 @@ public class TestOptimizers extends BaseDL4JTest {
private static class RosenbrockFunctionModel extends SimpleOptimizableModel { private static class RosenbrockFunctionModel extends SimpleOptimizableModel {
private static final long serialVersionUID = -5129494342531033706L; private static final long serialVersionUID = -5129494342531033706L;
private RosenbrockFunctionModel(int nDimensions, NeuralNetConfiguration conf) { private RosenbrockFunctionModel(int nDimensions, LayerConfiguration conf) {
super(initParams(nDimensions), conf); super(initParams(nDimensions), conf);
} }
@ -995,7 +996,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public void setListeners(TrainingListener... listeners) { public void addTrainingListeners(TrainingListener... listeners) {
} }
@ -1029,13 +1030,31 @@ public class TestOptimizers extends BaseDL4JTest {
private static final long serialVersionUID = 4409380971404019303L; private static final long serialVersionUID = 4409380971404019303L;
protected INDArray parameters; protected INDArray parameters;
protected INDArray gradientView; protected INDArray gradientView;
protected final NeuralNetConfiguration conf; protected final LayerConfiguration conf;
protected Gradient gradient; protected Gradient gradient;
protected double score; protected double score;
/**
* @return 1d parameter vector
*/
@Override
public INDArray getParams() {
throw new RuntimeException("Not implemented");
}
/**
* Get a reference to the network this layer is part of.
*
* @return the network (IModel) this layer is part of
*/
@Override
public IModel getNet() {
throw new RuntimeException("Not implemented");
}
/**@param parameterInit Initial parameters. Also determines dimensionality of problem. Should be row vector. /**@param parameterInit Initial parameters. Also determines dimensionality of problem. Should be row vector.
*/ */
private SimpleOptimizableModel(INDArray parameterInit, NeuralNetConfiguration conf) { private SimpleOptimizableModel(INDArray parameterInit, LayerConfiguration conf) {
this.parameters = parameterInit.dup(); this.parameters = parameterInit.dup();
this.gradientView = Nd4j.create(parameterInit.shape()); this.gradientView = Nd4j.create(parameterInit.shape());
this.conf = conf; this.conf = conf;
@ -1048,17 +1067,12 @@ public class TestOptimizers extends BaseDL4JTest {
*/ */
@Override @Override
public LayerConfiguration getLayerConfiguration() { public LayerConfiguration getLayerConfiguration() {
return this.conf.getFirstLayer(); return this.conf;
} }
@Override @Override
public void addListeners(TrainingListener... listener) { public ITraininableLayerConfiguration getTrainingConfig() {
// no-op return (BaseLayerConfiguration) conf;
}
@Override
public TrainingConfig getConfig() {
return conf.getFirstLayer();
} }
/** /**
@ -1092,7 +1106,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public void setListeners(TrainingListener... listeners) { public void addTrainingListeners(TrainingListener... listeners) {
} }
@ -1112,7 +1126,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public double score() { public double getScore() {
return score; return score;
} }
@ -1132,7 +1146,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public INDArray params() { public INDArray getModelParams() {
return parameters; return parameters;
} }
@ -1154,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest {
@Override @Override
public Pair<Gradient, Double> gradientAndScore() { public Pair<Gradient, Double> gradientAndScore() {
computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
return new Pair<>(gradient(), score()); return new Pair<>(gradient(), getScore());
} }
@Override @Override
@ -1164,7 +1178,7 @@ public class TestOptimizers extends BaseDL4JTest {
@Override @Override
public NeuralNetConfiguration getNetConfiguration() { public NeuralNetConfiguration getNetConfiguration() {
return conf; return conf.getNetConfiguration();
} }
@Override @Override
@ -1225,12 +1239,12 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public Collection<TrainingListener> getListeners() { public Collection<TrainingListener> getTrainingListeners() {
return null; return null;
} }
@Override @Override
public void setListeners(Collection<TrainingListener> listeners) { public void addTrainingListeners(Collection<TrainingListener> listeners) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -1310,4 +1324,6 @@ public class TestOptimizers extends BaseDL4JTest {
public void close(){ public void close(){
} }
} }
} }
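The TestOptimizers changes move the toy sphere/Rastrigin/Rosenbrock models from holding a whole NeuralNetConfiguration to holding a single LayerConfiguration, extracted with getFlattenedLayerConfigurations().get(0). Below is a hedged sketch of just that extraction step as it would appear inside a test body, built from the same calls as the hunks above; the line-search iteration count is an illustrative value, and registering the "W" variable stands in for what a ParamInitializer would normally do.

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.nd4j.linalg.learning.config.Sgd;

NeuralNetConfiguration netConf = NeuralNetConfiguration.builder()
        .maxNumLineSearchIterations(10)
        .updater(new Sgd(1e-2))
        .layer(new DenseLayer.Builder().nIn(1).nOut(1).build())
        .build();

// The optimizable test models now take a LayerConfiguration instead of the whole net configuration.
LayerConfiguration layerConf = netConf.getFlattenedLayerConfigurations().get(0);
layerConf.addVariable("W");                                       // normally done by a ParamInitializer
NeuralNetConfiguration owner = layerConf.getNetConfiguration();   // the layer config still knows its net config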

View File

@ -76,7 +76,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.keepAll() .keepAll()
.saveEveryNEpochs(2) .saveEveryNEpochs(2)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
for(int i=0; i<10; i++ ){ for(int i=0; i<10; i++ ){
net.fit(iter); net.fit(iter);
@ -125,7 +125,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.keepLast(3) .keepLast(3)
.saveEveryNIterations(5) .saveEveryNIterations(5)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
for(int i=0; i<20; i++ ){ //40 iterations total for(int i=0; i<20; i++ ){ //40 iterations total
net.fit(iter); net.fit(iter);
@ -167,7 +167,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
MultiLayerNetwork netStatic2 = CheckpointListener.loadLastCheckpointMLN(f); MultiLayerNetwork netStatic2 = CheckpointListener.loadLastCheckpointMLN(f);
assertEquals(35, netStatic2.getIterationCount()); assertEquals(35, netStatic2.getIterationCount());
assertEquals(netStatic.params(), netStatic2.params()); assertEquals(netStatic.getModelParams(), netStatic2.getModelParams());
} }
@Test @Test
@ -182,7 +182,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.keepLast(3) .keepLast(3)
.saveEvery(4900, TimeUnit.MILLISECONDS) .saveEvery(4900, TimeUnit.MILLISECONDS)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
for(int i=0; i<3; i++ ){ //10 iterations total for(int i=0; i<3; i++ ){ //10 iterations total
net.fit(iter); net.fit(iter);
@ -226,7 +226,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.keepLastAndEvery(3, 3) .keepLastAndEvery(3, 3)
.saveEveryNEpochs(2) .saveEveryNEpochs(2)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
for(int i=0; i<20; i++ ){ //40 iterations total for(int i=0; i<20; i++ ){ //40 iterations total
net.fit(iter); net.fit(iter);
@ -272,7 +272,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.keepAll() .keepAll()
.saveEveryNEpochs(1) .saveEveryNEpochs(1)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
for(int i=0; i<3; i++ ){ for(int i=0; i<3; i++ ){
net.fit(iter); net.fit(iter);
@ -294,7 +294,7 @@ public class TestCheckpointListener extends BaseDL4JTest {
.saveEveryNEpochs(1) .saveEveryNEpochs(1)
.deleteExisting(true) .deleteExisting(true)
.build(); .build();
net.setListeners(l); net.addTrainingListeners(l);
net.fit(iter); net.fit(iter);
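The checkpoint tests swap setListeners for addTrainingListeners as well. A brief sketch of wiring a CheckpointListener that way follows; net and iter stand for the test's MultiLayerNetwork and DataSetIterator, and the Builder entry point (new CheckpointListener.Builder(directory)) plus the import path are taken from stock DL4J and should be treated as assumptions, since the hunks only show the keepLast/saveEveryNIterations/build calls.

import java.io.File;
import org.deeplearning4j.optimize.listeners.CheckpointListener;

File checkpointDir = new File("checkpoints");         // hypothetical directory for saved models
CheckpointListener l = new CheckpointListener.Builder(checkpointDir)
        .keepLast(3)                                  // retain only the three most recent checkpoints
        .saveEveryNIterations(5)                      // write a checkpoint every 5 training iterations
        .build();
net.addTrainingListeners(l);                          // replaces the old net.setListeners(l)

for (int i = 0; i < 20; i++) {
    net.fit(iter);                                    // checkpoints are written during these fits
}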

View File

@ -58,7 +58,7 @@ public class TestFailureListener extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
net.setListeners(new FailureTestingListener( net.addTrainingListeners(new FailureTestingListener(
// FailureTestingListener.FailureMode.OOM, // FailureTestingListener.FailureMode.OOM,
FailureTestingListener.FailureMode.SYSTEM_EXIT_1, FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
new FailureTestingListener.IterationEpochTrigger(false, 10))); new FailureTestingListener.IterationEpochTrigger(false, 10)));
@ -84,7 +84,7 @@ public class TestFailureListener extends BaseDL4JTest {
assertNotNull(username); assertNotNull(username);
assertFalse(username.isEmpty()); assertFalse(username.isEmpty());
net.setListeners(new FailureTestingListener( net.addTrainingListeners(new FailureTestingListener(
FailureTestingListener.FailureMode.SYSTEM_EXIT_1, FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
new FailureTestingListener.Or( new FailureTestingListener.Or(
new FailureTestingListener.IterationEpochTrigger(false, 10000), new FailureTestingListener.IterationEpochTrigger(false, 10000),
@ -112,7 +112,7 @@ public class TestFailureListener extends BaseDL4JTest {
assertNotNull(hostname); assertNotNull(hostname);
assertFalse(hostname.isEmpty()); assertFalse(hostname.isEmpty());
net.setListeners(new FailureTestingListener( net.addTrainingListeners(new FailureTestingListener(
FailureTestingListener.FailureMode.ILLEGAL_STATE, FailureTestingListener.FailureMode.ILLEGAL_STATE,
new FailureTestingListener.And( new FailureTestingListener.And(
new FailureTestingListener.HostNameTrigger(hostname), new FailureTestingListener.HostNameTrigger(hostname),

View File

@@ -77,17 +77,17 @@ public class TestListeners extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
-net.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());
for (Layer l : net.getLayers()) {
-Collection<TrainingListener> layerListeners = l.getListeners();
+Collection<TrainingListener> layerListeners = l.getTrainingListeners();
assertEquals(2, layerListeners.size(), l.getClass().toString());
TrainingListener[] lArr = layerListeners.toArray(new TrainingListener[2]);
assertTrue(lArr[0] instanceof ScoreIterationListener);
assertTrue(lArr[1] instanceof TestRoutingListener);
}
-Collection<TrainingListener> netListeners = net.getListeners();
+Collection<TrainingListener> netListeners = net.getTrainingListeners();
assertEquals(2, netListeners.size());
TrainingListener[] lArr = netListeners.toArray(new TrainingListener[2]);
assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -101,17 +101,17 @@ public class TestListeners extends BaseDL4JTest {
ComputationGraph cg = new ComputationGraph(gConf);
cg.init();
-cg.setListeners(new ScoreIterationListener(), new TestRoutingListener());
+cg.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());
for (Layer l : cg.getLayers()) {
-Collection<TrainingListener> layerListeners = l.getListeners();
+Collection<TrainingListener> layerListeners = l.getTrainingListeners();
assertEquals(2, layerListeners.size());
lArr = layerListeners.toArray(new TrainingListener[2]);
assertTrue(lArr[0] instanceof ScoreIterationListener);
assertTrue(lArr[1] instanceof TestRoutingListener);
}
-netListeners = cg.getListeners();
+netListeners = cg.getTrainingListeners();
assertEquals(2, netListeners.size());
lArr = netListeners.toArray(new TrainingListener[2]);
assertTrue(lArr[0] instanceof ScoreIterationListener);
@@ -180,7 +180,7 @@ public class TestListeners extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
-net.setListeners(listeners);
+net.addTrainingListeners(listeners);
net.fit(iter);
@@ -199,7 +199,7 @@ public class TestListeners extends BaseDL4JTest {
listeners2.add(il2);
}
-net.setListeners(listeners2);
+net.addTrainingListeners(listeners2);
net.fit(iter);
}
@@ -216,7 +216,7 @@ public class TestListeners extends BaseDL4JTest {
net.init();
TestListener tl = new TestListener();
-net.setListeners(tl);
+net.addTrainingListeners(tl);
DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
@@ -260,7 +260,7 @@ public class TestListeners extends BaseDL4JTest {
tl = new TestListener();
ComputationGraph cg = net.toComputationGraph();
-cg.setListeners(tl);
+cg.addTrainingListeners(tl);
cg.fit(irisIter, 2);
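Note: the listener getters are renamed in the same way - net.getListeners() and l.getListeners() become getTrainingListeners() on networks, graphs and layers. A short sketch of the new inspection pattern, assuming a net initialised as in the test above and the test-local TestRoutingListener class (fragment, not a complete class):

    net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener());
    Collection<TrainingListener> netListeners = net.getTrainingListeners();      // was net.getListeners()
    for (Layer l : net.getLayers()) {
        Collection<TrainingListener> layerListeners = l.getTrainingListeners();  // was l.getListeners()
        assertEquals(netListeners.size(), layerListeners.size());                // listeners are propagated to every layer
    }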

View File

@@ -94,7 +94,7 @@ public class RandomTests extends BaseDL4JTest {
// at the end of day, model params has to
for (int i = 0; i < models.size(); i++) {
-assertEquals(models.get(0).params(), models.get(i).params());
+assertEquals(models.get(0).getModelParams(), models.get(i).getModelParams());
}
}
@@ -119,7 +119,7 @@ public class RandomTests extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
net2.init();
-assertEquals(net1.params(), net2.params());
+assertEquals(net1.getModelParams(), net2.getModelParams());
NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json);
@@ -127,6 +127,6 @@ public class RandomTests extends BaseDL4JTest {
MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
net3.init();
-assertEquals(net1.params(), net3.params());
+assertEquals(net1.getModelParams(), net3.getModelParams());
}
}
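Note: the parameter accessor rename (params() -> getModelParams()) runs through all of these tests; only the method name changes, not the check itself. A minimal sketch, assuming 'conf' is a fully built configuration with a fixed seed, as in the test above (fragment, not a complete class):

    MultiLayerNetwork net1 = new MultiLayerNetwork(conf);
    net1.init();
    MultiLayerNetwork net2 = new MultiLayerNetwork(conf);
    net2.init();
    assertEquals(net1.getModelParams(), net2.getModelParams());   // was net1.params() / net2.params()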

View File

@@ -63,7 +63,7 @@ public class TestSystemInfoPrintListener extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
-net.setListeners(systemInfoFilePrintListener);
+net.addTrainingListeners(systemInfoFilePrintListener);
DataSetIterator iter = new IrisDataSetIterator(10, 150);

View File

@@ -87,7 +87,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new Nesterovs().stateSize(net.numParams());
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -126,7 +126,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1));
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -170,7 +170,7 @@ public class RegressionTest050 extends BaseDL4JTest {
assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}

View File

@@ -89,7 +89,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new Nesterovs().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -132,7 +132,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -178,7 +178,7 @@ public class RegressionTest060 extends BaseDL4JTest {
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}

View File

@@ -90,7 +90,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6);
long numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new Nesterovs().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -133,7 +133,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
long numParams = net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -179,7 +179,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
long numParams = net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
}

View File

@@ -94,7 +94,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(0.15, n.getLearningRate(), 1e-6);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new Nesterovs().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -143,7 +143,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
@@ -194,7 +194,7 @@ public class RegressionTest080 extends BaseDL4JTest {
assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
int numParams = (int)net.numParams();
-assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params());
+assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams());
int updaterSize = (int) new RmsProp().stateSize(numParams);
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray());
}

View File

@@ -97,7 +97,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(dt, in.dataType());
assertEquals(dt, outExp.dataType());
-assertEquals(dt, net.params().dataType());
+assertEquals(dt, net.getModelParams().dataType());
assertEquals(dt, net.getFlattenedGradients().dataType());
assertEquals(dt, net.getUpdater().getStateViewArray().dataType());
@@ -109,7 +109,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
List<INDArray> activations = net.feedForward(in);
assertEquals(dt, net.getNetConfiguration().getDataType());
-assertEquals(dt, net.params().dataType());
+assertEquals(dt, net.getModelParams().dataType());
assertEquals( outExp, outAct, dtype);
}
}

View File

@@ -116,7 +116,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(dtype, in.dataType());
assertEquals(dtype, outExp.dataType());
-assertEquals(dtype, net.params().dataType());
+assertEquals(dtype, net.getModelParams().dataType());
assertEquals(dtype, net.getFlattenedGradients().dataType());
assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
@@ -126,7 +126,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getNetConfiguration().getDataType());
-assertEquals(dtype, net.params().dataType());
+assertEquals(dtype, net.getModelParams().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct);
}

View File

@@ -98,7 +98,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(dtype, in.dataType());
assertEquals(dtype, outExp.dataType());
-assertEquals(dtype, net.params().dataType());
+assertEquals(dtype, net.getModelParams().dataType());
assertEquals(dtype, net.getFlattenedGradients().dataType());
assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
@@ -108,7 +108,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getNetConfiguration().getDataType());
-assertEquals(dtype, net.params().dataType());
+assertEquals(dtype, net.getModelParams().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct);
}

View File

@@ -76,7 +76,7 @@ public class CustomLayer extends FeedForwardLayer {
//For the most part, it's the same for each type of layer
CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType);
-myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, if any
+myCustomLayer.addTrainingListeners(iterationListeners); //Set the iteration listeners, if any
myCustomLayer.setIndex(layerIndex); //Integer index of the layer
//Parameter view array: In Deeplearning4j, the network parameters for the entire network (all layers) are

View File

@@ -20,7 +20,6 @@
package org.deeplearning4j.regressiontest.customlayer100a;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
@@ -56,7 +55,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
INDArray firstHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
-IActivation activation1 = layerConf().getActivationFn();
+IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
//IActivation function instances modify the activation functions in-place
@@ -75,7 +74,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
/*
-The baockprop gradient method here is very similar to the BaseLayer backprop gradient implementation
+The baockprop gradient method here is very similar to the BaseLayerConfiguration backprop gradient implementation
The only major difference is the two activation functions we have added in this example.
Note that epsilon is dL/da - i.e., the derivative of the loss function with respect to the activations.
@@ -105,14 +104,14 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
INDArray epsilonFirstHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2));
INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
-IActivation activation1 = layerConf().getActivationFn();
+IActivation activation1 = getTypedLayerConfiguration().getActivationFn();
IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
//IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz
activation1.backprop(firstHalf, epsilonFirstHalf);
activation2.backprop(secondHalf, epsilonSecondHalf);
-//The remaining code for this method: just copy & pasted from BaseLayer.backpropGradient
+//The remaining code for this method: just copy & pasted from BaseLayerConfiguration.backpropGradient
// INDArray delta = epsilon.muli(activationDerivative);
if (maskArray != null) {
activationDerivative.muliColumnVector(maskArray);
@@ -128,7 +127,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad);
-INDArray epsilonNext = paramsTable.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();
+INDArray epsilonNext = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose();
return new Pair<>(ret, epsilonNext);
}
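Note: inside layer implementations the accessors change as well - layerConf() becomes getTypedLayerConfiguration() and the paramsTable field is read through getParamTable(). A sketch of the renamed calls inside a custom layer method, assuming the types and DefaultParamInitializer keys used in this class (fragment of a method body, not a complete class):

    IActivation activation1 = getTypedLayerConfiguration().getActivationFn();                         // was layerConf().getActivationFn()
    IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
    INDArray weights = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY);                       // was paramsTable.get(...)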

View File

@@ -190,7 +190,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
//Check score
-double scoreDl4j = net.score();
+double scoreDl4j = net.getScore();
double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
assertEquals(scoreDl4j, scoreSd, 1e-6, testName);
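Note: the score accessor follows the same getter naming - net.score() becomes net.getScore() - and, as before, it is compared against the SameDiff loss plus the regularization score. A one-line sketch of the renamed call, with the surrounding setup assumed from the test above:

    double scoreDl4j = net.getScore();   // was net.score()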

View File

@@ -104,7 +104,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
-net.addListeners(new ScoreIterationListener(1));
+net.addTrainingListeners(new ScoreIterationListener(1));
//Test net that hasn't been trained yet
Exception e = new Exception();
@@ -161,7 +161,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
CrashReportingUtil.crashDumpOutputDirectory(dir);
ComputationGraph cg = net.toComputationGraph();
-cg.setListeners(new ScoreIterationListener(1));
+cg.addTrainingListeners(new ScoreIterationListener(1));
//Test net that hasn't been trained yet
CrashReportingUtil.writeMemoryCrashDump(cg, e);

View File

@@ -156,7 +156,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-assertEquals(net.params(), network.params());
+assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@@ -173,7 +173,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
Assertions.assertNotNull(network);
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-assertEquals(net.params(), network.params());
+assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
}

View File

@@ -81,7 +81,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-assertEquals(net.params(), network.params());
+assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@@ -125,7 +125,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis);
assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
-assertEquals(net.params(), network.params());
+assertEquals(net.getModelParams(), network.getModelParams());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-assertEquals(cg.params(), network.params());
+assertEquals(cg.getModelParams(), network.getModelParams());
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
-assertEquals(cg.params(), network.params());
+assertEquals(cg.getModelParams(), network.getModelParams());
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
@@ -346,7 +346,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
//Also test reading both model and normalizer from stream (correctly)
Pair<MultiLayerNetwork,Normalizer> pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true);
-assertEquals(net.params(), pair.getFirst().params());
+assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
assertNotNull(pair.getSecond());
}
@@ -395,7 +395,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
//Also test reading both model and normalizer from stream (correctly)
Pair<ComputationGraph,Normalizer> pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true);
-assertEquals(net.params(), pair.getFirst().params());
+assertEquals(net.getModelParams(), pair.getFirst().getModelParams());
assertNotNull(pair.getSecond());
}
@@ -496,6 +496,6 @@ public class ModelSerializerTest extends BaseDL4JTest {
assertTrue(entries.contains("otherData.bin"));
ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile);
-assertEquals(net.params(), restoredNet.params());
+assertEquals(net.getModelParams(), restoredNet.getModelParams());
}
}
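Note: the serializer round-trip tests now compare getModelParams() after save and restore. A minimal sketch, assuming ModelSerializer.writeModel with the usual save-updater flag, a hypothetical file path, and a trained 'net' as in the tests above (fragment, not a complete class):

    File tempFile = new File("model.zip");              // hypothetical path
    ModelSerializer.writeModel(net, tempFile, true);    // true: also persist the updater state
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tempFile);
    assertEquals(net.getModelParams(), restored.getModelParams());                                   // was params()
    assertEquals(net.getUpdater().getStateViewArray(), restored.getUpdater().getStateViewArray());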

View File

@@ -21,7 +21,6 @@
package org.deeplearning4j.nn.modelimport.keras.layers;
import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
@@ -80,10 +79,6 @@ public class TFOpLayer extends LayerConfiguration {
public void setNIn(InputType inputType, boolean override){}
-@Override
-public GradientNormalization getGradientNormalization(){return null;}
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
@@ -91,14 +86,11 @@ public class TFOpLayer extends LayerConfiguration {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, lconf, networkDataType);
-tfOpLayerImpl.setListeners(trainingListeners);
+tfOpLayerImpl.addTrainingListeners(trainingListeners);
tfOpLayerImpl.setIndex(layerIndex);
return tfOpLayerImpl;
}
-@Override
-public double getGradientNormalizationThreshold(){return 0.;}
@Override
public List<Regularization> getRegularizationByParam(String paramName){return null;}

View File

@@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
-import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
@@ -448,8 +448,8 @@ public class KerasLSTM extends KerasLayer {
FeedForwardLayer ffl;
-if(this.layer instanceof BaseWrapperLayer){
-BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer;
+if(this.layer instanceof BaseWrapperLayerConfiguration){
+BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration)this.layer;
ffl = (FeedForwardLayer)bwl.getUnderlying();
} else {
ffl = (FeedForwardLayer) this.layer;

Some files were not shown because too many files have changed in this diff.