Playing with some new code 2 - clean build

Signed-off-by: brian <brian@brutex.de>
Branch: master
Brian Rosenberger, 2023-04-07 17:05:32 +02:00
parent 3edb90dbd1
commit a5dfdcb18f
92 changed files with 716 additions and 318 deletions
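The hunks below all follow one refactoring pattern: `paramTable()` becomes `getParamTable()`, `getConfiguration()` on MultiLayerNetwork becomes `getNetConfiguration()`, `NeuralNetConfiguration.getFirstLayer()` is replaced by `getFlattenedLayerConfigurations().get(0)`, and custom layer implementations now take a `LayerConfiguration` instead of a `NeuralNetConfiguration`. As a minimal illustrative sketch only (the helper class and method are hypothetical; the accessors are assumed to behave as shown in the hunks), caller code updated for this commit would look like this:

import java.util.Map;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RenamedAccessorSketch {

    // Hypothetical helper, not part of this commit: it only strings together the
    // renamed getters exactly as they appear in the hunks below.
    static void inspect(MultiLayerNetwork net) {
        // Previously net.getConfiguration(), now:
        NeuralNetConfiguration conf = net.getNetConfiguration();

        // Previously net.paramTable(), now:
        Map<String, INDArray> params = net.getParamTable();

        // Previously conf.getFirstLayer(), now the first flattened layer configuration:
        LayerConfiguration first = conf.getFlattenedLayerConfigurations().get(0);

        System.out.println(params.size() + " params, first layer config: " + first);
    }
}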

View File

@@ -74,7 +74,7 @@ public class TestFrozenLayers extends BaseSparkTest {
         MultiLayerNetwork withFrozen = new TransferLearning.Builder(origModel).fineTuneConfiguration(finetune)
                 .setFeatureExtractor(1).build();
-        Map<String, INDArray> m = withFrozen.paramTable();
+        Map<String, INDArray> m = withFrozen.getParamTable();
         Map<String, INDArray> pCopy = new HashMap<>();
         for (Map.Entry<String, INDArray> entry : m.entrySet()) {
             pCopy.put(entry.getKey(), entry.getValue().dup());
@@ -110,7 +110,7 @@ public class TestFrozenLayers extends BaseSparkTest {
         MultiLayerNetwork fitted = sNet.getNetwork();
-        Map<String, INDArray> fittedParams = fitted.paramTable();
+        Map<String, INDArray> fittedParams = fitted.getParamTable();
         for (Map.Entry<String, INDArray> entry : fittedParams.entrySet()) {
             INDArray orig = pCopy.get(entry.getKey());
@@ -151,7 +151,7 @@ public class TestFrozenLayers extends BaseSparkTest {
         ComputationGraph withFrozen = new TransferLearning.GraphBuilder(origModel).fineTuneConfiguration(finetune)
                 .setFeatureExtractor("1").build();
-        Map<String, INDArray> m = withFrozen.paramTable();
+        Map<String, INDArray> m = withFrozen.getParamTable();
         Map<String, INDArray> pCopy = new HashMap<>();
         for (Map.Entry<String, INDArray> entry : m.entrySet()) {
             pCopy.put(entry.getKey(), entry.getValue().dup());
@@ -187,7 +187,7 @@ public class TestFrozenLayers extends BaseSparkTest {
         ComputationGraph fitted = sNet.getNetwork();
-        Map<String, INDArray> fittedParams = fitted.paramTable();
+        Map<String, INDArray> fittedParams = fitted.getParamTable();
         for (Map.Entry<String, INDArray> entry : fittedParams.entrySet()) {
             INDArray orig = pCopy.get(entry.getKey());

View File

@@ -200,8 +200,8 @@ public class GAN {
         Layer[] disLayers = ganDiscriminator.getLayers();
         Layer[] layers = ArrayUtils.addAll(genLayers, disLayers);
-        NeuralNetConfiguration genConf = generator.getConfiguration();
-        NeuralNetConfiguration disConf = ganDiscriminator.getConfiguration();
+        NeuralNetConfiguration genConf = generator.getNetConfiguration();
+        NeuralNetConfiguration disConf = ganDiscriminator.getNetConfiguration();
         LayerConfiguration[] confLayers = new LayerConfiguration[layers.length];
         Map<Integer, InputPreProcessor> preProcessors = new HashMap<>();

View File

@@ -190,7 +190,7 @@ public class IntegrationTestRunner {
             m = mln;
             MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true);
-            assertEquals(loaded.getConfiguration(), mln.getConfiguration(), "Configs not equal");
+            assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal");
             assertEquals( loaded.params(), mln.params(), "Params not equal");
             assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal");
         } else if(config instanceof ComputationGraphConfiguration ){
@@ -202,7 +202,7 @@ public class IntegrationTestRunner {
             ComputationGraph loaded = ComputationGraph.load(savedModel, true);
             assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" );
             assertEquals( loaded.params(), cg.params(), "Params not equal");
-            assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal");
+            assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal");
         } else if(config instanceof SameDiff){
             sd = (SameDiff)config;
             SameDiff loaded = SameDiff.load(savedModel, true);
@@ -426,8 +426,8 @@ public class IntegrationTestRunner {
         boolean isTbptt;
         int tbpttLength;
         if(modelType == ModelType.MLN){
-            isTbptt = mln.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
-            tbpttLength = mln.getConfiguration().getTbpttFwdLength();
+            isTbptt = mln.getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
+            tbpttLength = mln.getNetConfiguration().getTbpttFwdLength();
         } else if(modelType == ModelType.CG) {
             isTbptt = cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
             tbpttLength = cg.getComputationGraphConfiguration().getTbpttFwdLength();
@@ -606,7 +606,7 @@ public class IntegrationTestRunner {
         if (modelType == ModelType.MLN) {
             ModelSerializer.writeModel(m, f, true);
             MultiLayerNetwork restored = MultiLayerNetwork.load(f, true);
-            assertEquals(mln.getConfiguration(), restored.getConfiguration());
+            assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration());
             assertEquals(mln.params(), restored.params());
         } else if(modelType == ModelType.CG){
             ModelSerializer.writeModel(m, f, true);
@@ -742,7 +742,7 @@ public class IntegrationTestRunner {
         //Collect preprocessor coverage information:
         Collection<InputPreProcessor> preProcessors;
         if (isMLN) {
-            preProcessors = mln.getConfiguration().getInputPreProcessors().values();
+            preProcessors = mln.getNetConfiguration().getInputPreProcessors().values();
         } else {
             preProcessors = new ArrayList<>();
             for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) {
@@ -834,7 +834,7 @@ public class IntegrationTestRunner {
             } else {
                 paramPrefix = l.getLayerConfiguration().getLayerName() + "_";
             }
-            Map<String,INDArray> paramTable = l.paramTable();
+            Map<String,INDArray> paramTable = l.getParamTable();
             for(Map.Entry<String,INDArray> e : paramTable.entrySet()){
                 out.put(paramPrefix + e.getKey(), e.getValue().dup());
             }
@@ -1088,7 +1088,7 @@ public class IntegrationTestRunner {
             if(pSoFar + n < i){
                 pSoFar += n;
             } else {
-                for(Map.Entry<String,INDArray> e : l.paramTable().entrySet()){
+                for(Map.Entry<String,INDArray> e : l.getParamTable().entrySet()){
                     pSoFar += e.getValue().length();
                     if(pSoFar >= i){
                         pName = e.getKey();

View File

@@ -48,7 +48,7 @@ public class TestUtils {
             ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
-            assertEquals(net.getConfiguration(), restored.getConfiguration());
+            assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
             assertEquals(net.params(), restored.params());
         } catch (IOException e){
             //Should never happen
@@ -56,7 +56,7 @@ public class TestUtils {
         }
         //Also check the NeuralNetConfiguration is serializable (required by Spark etc)
-        NeuralNetConfiguration conf = net.getConfiguration();
+        NeuralNetConfiguration conf = net.getNetConfiguration();
         serializeDeserializeJava(conf);
         return restored;

View File

@@ -109,12 +109,12 @@ public class LayerHelperValidationUtil {
         }
-        MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getConfiguration().clone());
+        MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net1NoHelper.init();
         log.info("Removing all layer helpers from network copy 1");
         removeHelpers(net1NoHelper.getLayers(), null);
-        MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone());
+        MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net2With.init();
         net2With.params().assign(netOrig.params());
         log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses());
@@ -253,7 +253,7 @@ public class LayerHelperValidationUtil {
         Preconditions.checkNotNull(t.getData(), "DataSetIterator is not set (null)");
         log.info("Testing run-to-run consistency of training with layer helper");
-        net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone());
+        net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
         net2With.init();
         net2With.params().assign(netOrig.params());
         log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());
@@ -265,7 +265,7 @@ public class LayerHelperValidationUtil {
         for( int i=0; i<2; i++ ) {
-            net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone());
+            net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone());
             net2With.init();
             net2With.params().assign(netOrig.params());
             log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses());

View File

@@ -66,7 +66,7 @@ public class TestUtils {
             ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
-            assertEquals(net.getConfiguration(), restored.getConfiguration());
+            assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
             assertEquals(net.params(), restored.params());
         } catch (IOException e){
             //Should never happen
@@ -74,7 +74,7 @@ public class TestUtils {
         }
         //Also check the NeuralNetConfiguration is serializable (required by Spark etc)
-        NeuralNetConfiguration conf = net.getConfiguration();
+        NeuralNetConfiguration conf = net.getNetConfiguration();
         serializeDeserializeJava(conf);
         return restored;

View File

@@ -622,7 +622,7 @@ public class EvalTest extends BaseDL4JTest {
         //Disable validation, and check same thing:
-        net.getConfiguration().setValidateOutputLayerConfig(false);
+        net.getNetConfiguration().setValidateOutputLayerConfig(false);
         net.evaluate(iter);
         net.evaluateROCMultiClass(iter, 0);

View File

@@ -511,7 +511,7 @@ public class GradientCheckTests extends BaseDL4JTest {
         ComputationGraph netGraph = new ComputationGraph(conf);
         netGraph.init();
-        log.info("params before learning: " + netGraph.getLayer(1).paramTable());
+        log.info("params before learning: " + netGraph.getLayer(1).getParamTable());
         //Run a number of iterations of learning manually make some pseudo data
         //the ides is simple: since we do a element wise multiplication layer (just a scaling), we want the cos sim
@@ -538,7 +538,7 @@ public class GradientCheckTests extends BaseDL4JTest {
         assertTrue( scoreAfter < 0.8 * scoreBefore, msg);
         // expectation in case linear regression(with only element wise multiplication layer): large weight for the fourth weight
-        log.info("params after learning: " + netGraph.getLayer(1).paramTable());
+        log.info("params after learning: " + netGraph.getLayer(1).getParamTable());
         boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(netGraph).inputs(new INDArray[]{features})
                 .labels(new INDArray[]{labels}));

View File

@@ -100,14 +100,14 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
     @Test
     public void testClone() {
         NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
-        BaseLayer bl = (BaseLayer) conf.getFirstLayer();
+        BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0);
         conf.setStepFunction(new DefaultStepFunction());
         NeuralNetConfiguration conf2 = conf.clone();
         assertEquals(conf, conf2);
         assertNotSame(conf, conf2);
-        assertNotSame(conf.getFirstLayer(), conf2.getFirstLayer());
+        assertNotSame(conf.getFlattenedLayerConfigurations().get(0), conf2.getFlattenedLayerConfigurations().get(0));
         assertNotSame(conf.getStepFunction(), conf2.getStepFunction());
     }
@@ -119,9 +119,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123)
                 .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
-        long numParams = conf.getFirstLayer().initializer().numParams(conf);
+        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        Layer model = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        Layer model = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
         INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
@@ -130,9 +130,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
         NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(123)
                 .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build();
-        long numParams2 = conf2.getFirstLayer().initializer().numParams(conf);
+        long numParams2 = conf2.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params2 = Nd4j.create(1, numParams);
-        Layer model2 = conf2.getFirstLayer().instantiate(conf2, null, 0, params2, true, params.dataType());
+        Layer model2 = conf2.getFlattenedLayerConfigurations().get(0).instantiate(conf2, null, 0, params2, true, params.dataType());
         INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
         assertEquals(modelWeights, modelWeights2);
@@ -208,9 +208,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
     private static Layer getLayer(int nIn, int nOut, IWeightInit weightInit, boolean preTrain) {
         NeuralNetConfiguration conf = getConfig(nIn, nOut, weightInit, preTrain);
-        long numParams = conf.getFirstLayer().initializer().numParams(conf);
+        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
     }
@@ -235,7 +235,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        ConvexOptimizer opt = new StochasticGradientDescent(net.getConfiguration(),
+        ConvexOptimizer opt = new StochasticGradientDescent(net.getNetConfiguration(),
                 new NegativeDefaultStepFunction(), null, net);
         assertEquals(lr, ((Sgd)net.getLayer(0).getLayerConfiguration().getUpdaterByParam("W")).getLearningRate(), 1e-4);
         assertEquals(biasLr, ((Sgd)net.getLayer(0).getLayerConfiguration().getUpdaterByParam("b")).getLearningRate(), 1e-4);
@@ -295,7 +295,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        ConvexOptimizer opt = new StochasticGradientDescent(net.getConfiguration(),
+        ConvexOptimizer opt = new StochasticGradientDescent(net.getNetConfiguration(),
                 new NegativeDefaultStepFunction(), null, net);
         assertEquals(l1, TestUtils.getL1(net.getLayer(0).getLayerConfiguration().getRegularizationByParam("W")), 1e-4);
         List<Regularization> r = net.getLayer(0).getLayerConfiguration().getRegularizationByParam("b");

View File

@@ -456,7 +456,7 @@ public class TestConstraints extends BaseDL4JTest {
         INDArray label = Nd4j.rand(1, 1);
         g.fit(new INDArray[]{in1, in2}, new INDArray[]{label});
-        for(Map.Entry<String,INDArray> e : g.paramTable().entrySet()){
+        for(Map.Entry<String,INDArray> e : g.getParamTable().entrySet()){
             if(!e.getKey().contains("W")){
                 continue;
             }

View File

@@ -82,9 +82,9 @@ public class TestDropout extends BaseDL4JTest {
                 .setOutputs("2")
                 .build();
-        assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getNetConfiguration().getFirstLayer().getIDropout());
-        assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getNetConfiguration().getFirstLayer().getIDropout());
-        assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getNetConfiguration().getFirstLayer().getIDropout());
+        assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getLayerConfiguration().getIDropout());
+        assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getLayerConfiguration().getIDropout());
+        assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getLayerConfiguration().getIDropout());
     }
     @Test

View File

@@ -232,7 +232,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
         cg.computeGradientAndScore();
         // Let's figure out what our params are now.
-        Map<String, INDArray> params = cg.paramTable();
+        Map<String, INDArray> params = cg.getParamTable();
         INDArray dense1_W = nullsafe(params.get("dense1_W"));
         INDArray dense1_b = nullsafe(params.get("dense1_b"));
         INDArray dense2_W = nullsafe(params.get("dense2_W"));
@@ -408,7 +408,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
         cg.computeGradientAndScore();
         // Let's figure out what our params are now.
-        Map<String, INDArray> params = cg.paramTable();
+        Map<String, INDArray> params = cg.getParamTable();
         INDArray dense1_W = nullsafe(params.get("dense1_W"));
         INDArray dense1_b = nullsafe(params.get("dense1_b"));
         INDArray dense2_W = nullsafe(params.get("dense2_W"));
@@ -578,7 +578,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
         cg.computeGradientAndScore();
         // Let's figure out what our params are now.
-        Map<String, INDArray> params = cg.paramTable();
+        Map<String, INDArray> params = cg.getParamTable();
         INDArray dense1_W = nullsafe(params.get("dense1_W"));
         INDArray dense1_b = nullsafe(params.get("dense1_b"));
         INDArray dense2_W = nullsafe(params.get("dense2_W"));

View File

@@ -159,7 +159,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
         cg.setLabel(0, target);
         cg.computeGradientAndScore();
         double score_dl4j = cg.score();
-        Map<String, INDArray> weights = cg.paramTable();
+        Map<String, INDArray> weights = cg.getParamTable();
         Gradient g = cg.gradient();
         Map<String, INDArray> gradients = g.gradientForVariable();
         Map<String, INDArray> manual_gradients = new TreeMap<String, INDArray>();

View File

@@ -212,21 +212,21 @@ public class LayerBuilderTest extends BaseDL4JTest {
         try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) {
             confActual = (NeuralNetConfiguration) in.readObject();
         }
-        assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal Java serialization");
+        assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal Java serialization");
         // check JSON
         String json = confExpected.toJson();
         confActual = NeuralNetConfiguration.fromJson(json);
-        assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal JSON serialization");
+        assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal JSON serialization");
         // check YAML
         String yaml = confExpected.toYaml();
         confActual = NeuralNetConfiguration.fromYaml(yaml);
-        assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal YAML serialization");
+        assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal YAML serialization");
         // check the layer's use of callSuper on equals method
-        confActual.getFirstLayer().setIDropout(new Dropout(new java.util.Random().nextDouble()));
-        assertNotEquals( confExpected.getFirstLayer(), confActual.getFirstLayer(), "broken equals method (missing callSuper?)");
+        confActual.getFlattenedLayerConfigurations().get(0).setIDropout(new Dropout(new java.util.Random().nextDouble()));
+        assertNotEquals( confExpected, confActual, "broken equals method (missing callSuper?)");
     }
 }

View File

@@ -62,9 +62,9 @@ public class TestPreProcessors extends BaseDL4JTest {
                         .nOut(layerSize).build())
                 .build();
-        long numParams = nnc.getFirstLayer().initializer().numParams(nnc);
+        long numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc);
         INDArray params = Nd4j.create(1, numParams);
-        DenseLayer layer = (DenseLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType());
+        DenseLayer layer = (DenseLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType());
         layer.setInputMiniBatchSize(miniBatchSize);
         INDArray activations3dc = Nd4j.create(new int[] {miniBatchSize, layerSize, timeSeriesLength}, 'c');
@@ -147,9 +147,9 @@ public class TestPreProcessors extends BaseDL4JTest {
                         .nOut(layerSize).build())
                 .build();
-        val numParams = nnc.getFirstLayer().initializer().numParams(nnc);
+        val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc);
         INDArray params = Nd4j.create(1, numParams);
-        DenseLayer layer = (DenseLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType());
+        DenseLayer layer = (DenseLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType());
         layer.setInputMiniBatchSize(miniBatchSize);
         INDArray rand = Nd4j.rand(miniBatchSize * timeSeriesLength, layerSize);
@@ -232,10 +232,10 @@ public class TestPreProcessors extends BaseDL4JTest {
                         .nOut(nChannels).build())
                 .build();
-        val numParams = nnc.getFirstLayer().initializer().numParams(nnc);
+        val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc);
         INDArray params = Nd4j.create(1, numParams);
         ConvolutionLayer layer =
-                (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType());
+                (ConvolutionLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType());
         layer.setInputMiniBatchSize(miniBatchSize);
         INDArray activationsCnn = Nd4j.rand(miniBatchSize * timeSeriesLength, nChannels,
@@ -314,10 +314,10 @@ public class TestPreProcessors extends BaseDL4JTest {
                         .nOut(nChannels).build())
                 .build();
-        val numParams = nnc.getFirstLayer().initializer().numParams(nnc);
+        val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc);
         INDArray params = Nd4j.create(1, numParams);
         ConvolutionLayer layer =
-                (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType());
+                (ConvolutionLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType());
         layer.setInputMiniBatchSize(miniBatchSize);
         val shape_rnn = new long[] {miniBatchSize, nChannels * inputHeight * inputWidth,

View File

@@ -256,9 +256,9 @@ public class DTypeTests extends BaseDL4JTest {
     }
     public static void logUsedClasses(MultiLayerNetwork net) {
-        NeuralNetConfiguration conf = net.getConfiguration();
+        NeuralNetConfiguration conf = net.getNetConfiguration();
         for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) {
-            LayerConfiguration l = nnc.getFirstLayer();
+            LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0);
             seenLayers.add(l.getClass());
             if (l instanceof BaseWrapperLayer) {
                 BaseWrapperLayer bwl = (BaseWrapperLayer) l;
@@ -281,7 +281,7 @@ public class DTypeTests extends BaseDL4JTest {
         for (GraphVertex gv : conf.getVertices().values()) {
             seenVertices.add(gv.getClass());
             if (gv instanceof LayerVertex) {
-                seenLayers.add(((LayerVertex) gv).getNetConfiguration().getFirstLayer().getClass());
+                seenLayers.add(((LayerVertex) gv).getLayerConfiguration().getClass());
                 InputPreProcessor ipp = ((LayerVertex) gv).getPreProcessor();
                 if (ipp != null) {
                     seenPreprocs.add(ipp.getClass());

View File

@@ -96,11 +96,11 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
             Map<String,INDArray> paramsBefore = new HashMap<>();
             //Pretrain first layer
-            for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
+            for(Map.Entry<String,INDArray> e : cg.getParamTable().entrySet()){
                 paramsBefore.put(e.getKey(), e.getValue().dup());
             }
             cg.pretrainLayer("vae1", ds);
-            for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
+            for(Map.Entry<String,INDArray> e : cg.getParamTable().entrySet()){
                 if(e.getKey().startsWith("vae1")){
                     assertNotEquals(paramsBefore.get(e.getKey()), e.getValue());
                 } else {
@@ -113,11 +113,11 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
             //Pretrain second layer
-            for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
+            for(Map.Entry<String,INDArray> e : cg.getParamTable().entrySet()){
                 paramsBefore.put(e.getKey(), e.getValue().dup());
             }
             cg.pretrainLayer("vae2", ds);
-            for(Map.Entry<String,INDArray> e : cg.paramTable().entrySet()){
+            for(Map.Entry<String,INDArray> e : cg.getParamTable().entrySet()){
                 if(e.getKey().startsWith("vae2")){
                     assertNotEquals(paramsBefore.get(e.getKey()), e.getValue());
                 } else {

View File

@@ -406,9 +406,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 .addLayer("rnn", new GravesLSTM.Builder().nOut(5).build(), "in")
                 .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "rnn").setOutputs("out").build();
-        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getNetConfiguration().getFirstLayer())
+        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getNIn());
-        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getNetConfiguration().getFirstLayer())
+        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getNIn());
         LayerVertex lv1 = (LayerVertex) conf1.getVertices().get("rnn");
@@ -423,9 +423,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "ff")
                 .setOutputs("out").build();
-        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getNetConfiguration().getFirstLayer())
+        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getNIn());
-        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getNetConfiguration().getFirstLayer())
+        assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getNIn());
         lv1 = (LayerVertex) conf2.getVertices().get("ff");
@@ -460,7 +460,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         LayerVertex lv4 = (LayerVertex) conf3.getVertices().get("out");
         assertNull(lv4.getPreProcessor());
         //Check nIns:
-        assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn());
+        assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn());
         //CNN->Dense, RNN->Dense, Dense->RNN
         ComputationGraphConfiguration conf4 =
@@ -495,9 +495,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         LayerVertex lv5 = (LayerVertex) conf4.getVertices().get("out");
         assertTrue(lv5.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);
         //Check nIns:
-        assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn());
-        assertEquals(5, ((FeedForwardLayer) lv4.getNetConfiguration().getFirstLayer()).getNIn());
-        assertEquals(20, ((FeedForwardLayer) lv5.getNetConfiguration().getFirstLayer()).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex
+        assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn());
+        assertEquals(5, ((FeedForwardLayer) lv4.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn());
+        assertEquals(20, ((FeedForwardLayer) lv5.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex
         //Input to 2 CNN layers:
@@ -903,7 +903,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 .build();
         LayerVertex lv = (LayerVertex) conf.getVertices().get("layer");
-        FeedForwardLayer l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer());
+        FeedForwardLayer l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0));
         assertEquals(3, l.getNIn());
         assertNull(lv.getPreProcessor());
@@ -920,7 +920,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 .build();
         lv = (LayerVertex) conf.getVertices().get("layer");
-        l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer());
+        l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0));
         assertEquals(3, l.getNIn());
         assertNotNull(lv.getPreProcessor());
         InputPreProcessor preProcessor = lv.getPreProcessor();
@@ -945,7 +945,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         //Check subsampling layer:
         lv = (LayerVertex) conf.getVertices().get("l0");
-        SubsamplingLayer sl = ((SubsamplingLayer) (lv).getNetConfiguration().getFirstLayer());
+        SubsamplingLayer sl = ((SubsamplingLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0));
         assertNotNull(lv.getPreProcessor());
         preProcessor = lv.getPreProcessor();
         assertTrue(preProcessor instanceof FeedForwardToCnnPreProcessor);
@@ -955,7 +955,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         assertEquals(3, preproc.getNumChannels());
         //Check dense layer
         lv = (LayerVertex) conf.getVertices().get("layer");
-        l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer());
+        l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0));
         assertEquals(3, l.getNIn());
         assertNull(lv.getPreProcessor());
@@ -1673,7 +1673,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         ComputationGraph g = new ComputationGraph(conf2);
         g.init();
-        g.setParamTable(cg.paramTable());
+        g.setParamTable(cg.getParamTable());
         int[] origOrder = g.topologicalSortOrder();
         INDArray[] out4 = g.output(in);

View File

@@ -72,9 +72,9 @@ public class TestSetGetParameters extends BaseDL4JTest {
         assertSame(params, net3.params()); //Same object due to clone
-        Map<String, INDArray> paramsMap = net.paramTable();
-        Map<String, INDArray> paramsMap2 = net2.paramTable();
-        Map<String, INDArray> paramsMap3 = net3.paramTable();
+        Map<String, INDArray> paramsMap = net.getParamTable();
+        Map<String, INDArray> paramsMap2 = net2.getParamTable();
+        Map<String, INDArray> paramsMap3 = net3.getParamTable();
         for (String s : paramsMap.keySet()) {
             assertEquals(paramsMap.get(s), paramsMap2.get(s));
             assertEquals(paramsMap.get(s), paramsMap3.get(s));

View File

@@ -57,10 +57,10 @@ public class BaseLayerTest extends BaseDL4JTest {
     @Test
     public void testSetExistingParamsConvolutionSingleLayer() {
         Layer layer = configureSingleLayer();
-        assertNotEquals(paramTable, layer.paramTable());
+        assertNotEquals(paramTable, layer.getParamTable());
         layer.setParamTable(paramTable);
-        assertEquals(paramTable, layer.paramTable());
+        assertEquals(paramTable, layer.getParamTable());
     }
@@ -69,9 +69,9 @@ public class BaseLayerTest extends BaseDL4JTest {
         MultiLayerNetwork net = configureMultiLayer();
         for (Layer layer : net.getLayers()) {
-            assertNotEquals(paramTable, layer.paramTable());
+            assertNotEquals(paramTable, layer.getParamTable());
             layer.setParamTable(paramTable);
-            assertEquals(paramTable, layer.paramTable());
+            assertEquals(paramTable, layer.getParamTable());
         }
     }
@@ -83,9 +83,9 @@ public class BaseLayerTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build();
-        val numParams = conf.getFirstLayer().initializer().numParams(conf);
+        val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
     }

View File

@@ -133,7 +133,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
         MultiLayerNetwork clonedModel = modelNow.clone();
         //Check json
-        assertEquals(modelNow.getConfiguration().toJson(), clonedModel.getConfiguration().toJson());
+        assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson());
         //Check params
         assertEquals(modelNow.params(), clonedModel.params());

View File

@@ -64,9 +64,9 @@ public class OutputLayerTest extends BaseDL4JTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .build();
-        long numParams = conf.getFirstLayer().initializer().numParams(conf);
+        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        OutputLayer l = (OutputLayer) conf.getFirstLayer().instantiate(conf,
+        OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf,
                 Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
         params = l.params();
         l.setParamsTable(params);

View File

@@ -43,7 +43,7 @@ public class RepeatVectorTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123)
                 .dataType(DataType.DOUBLE)
                 .layer(new RepeatVector.Builder(REPEAT).build()).build();
-        return conf.getFirstLayer().instantiate(conf, null, 0,
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0,
                 null, false, DataType.DOUBLE);
     }

View File

@@ -52,9 +52,9 @@ public class SeedTest extends BaseDL4JTest {
         NeuralNetConfiguration conf =
                 NeuralNetConfiguration.builder().layer(layerType).seed(123).build();
-        long numParams = conf.getFirstLayer().initializer().numParams(conf);
+        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
         layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());

View File

@@ -90,9 +90,9 @@ public class Convolution3DTest extends BaseDL4JTest {
                         .dataFormat(Convolution3D.DataFormat.NCDHW).convolutionMode(mode).hasBias(false)
                         .build())
                 .build();
-        long numParams = conf.getFirstLayer().initializer().numParams(conf);
+        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.ones(1, numParams);
-        return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
     }
     public INDArray getData() throws Exception {

View File

@@ -258,9 +258,9 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(cnn).build();
-        val numParams = conf.getFirstLayer().initializer().numParams(conf);
+        val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
         assertEquals(1, layer.getParam("b").size(0));
     }
@@ -319,9 +319,9 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(layer).build();
-        val numParams = conf.getFirstLayer().initializer().numParams(conf);
+        val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
-        return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
     }
     public Layer getMNISTConfig() {

View File

@@ -62,7 +62,7 @@ public class SpaceToDepthTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
                 .layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build();
-        return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
     }
     @Test

View File

@@ -172,7 +172,7 @@ public class SubsamplingLayerTest extends BaseDL4JTest {
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
                 .layer(new SubsamplingLayer.Builder(pooling, new int[] {2, 2}).build()).build();
-        return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
     }
     public INDArray getData() throws Exception {

View File

@@ -287,28 +287,28 @@ public class TestConvolutionModes extends BaseDL4JTest {
                                 .activation(Activation.SOFTMAX).nOut(3).build(), "7")
                 .setOutputs("8").build();
-        assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer())
+        assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Strict,
-                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer())
+                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Truncate,
-                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFirstLayer())
+                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Same,
-                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getNetConfiguration().getFirstLayer())
+                ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
-        assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getNetConfiguration().getFirstLayer())
+        assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Strict,
-                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getNetConfiguration().getFirstLayer())
+                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Truncate,
-                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getNetConfiguration().getFirstLayer())
+                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
         assertEquals(ConvolutionMode.Same,
-                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getNetConfiguration().getFirstLayer())
+                ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getNetConfiguration().getFlattenedLayerConfigurations().get(0))
                 .getConvolutionMode());
     }
 }

View File

@@ -107,7 +107,7 @@ public class Upsampling1DTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
                 .layer(new Upsampling1D.Builder(size).build()).build();
-        return conf.getFirstLayer().instantiate(conf, null, 0,
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0,
                 null, true, Nd4j.defaultFloatingPointType());
     }

View File

@@ -111,7 +111,7 @@ public class Upsampling2DTest extends BaseDL4JTest {
         NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                 .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
                 .layer(new Upsampling2D.Builder(size).build()).build();
-        return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
+        return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
     }
     public INDArray getData() throws Exception {

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
@ -53,13 +54,14 @@ public class CustomLayer extends FeedForwardLayer {
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) { boolean initializeParams, DataType networkDataType) {
CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType);
ret.setListeners(trainingListeners); ret.setListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
ret.setParamTable(paramTable); ret.setParamTable(paramTable);
ret.setLayerConfiguration(conf); ret.setLayerConfiguration(lconf);
return ret; return ret;
} }
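
The same plumbing recurs in every custom layer touched by this commit: instantiate() selects its own LayerConfiguration by index from the flattened list and hands that, rather than the whole NeuralNetConfiguration, to the layer implementation. A condensed sketch of the pattern with placeholder names (MyLayer, MyLayerImpl):

// Inside a hypothetical MyLayer configuration class; only the configuration plumbing is shown.
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
        Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
        boolean initializeParams, DataType networkDataType) {
    LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); // this layer's own config
    MyLayerImpl ret = new MyLayerImpl(lconf, networkDataType);    // impl constructor now takes LayerConfiguration
    ret.setListeners(trainingListeners);
    ret.setIndex(layerIndex);
    ret.setParamsViewArray(layerParamsView);
    ret.setParamTable(initializer().init(this, layerParamsView, initializeParams));
    ret.setLayerConfiguration(lconf);                             // store the per-layer config, not the net config
    return ret;
}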

View File

@ -21,11 +21,12 @@
package org.deeplearning4j.nn.layers.custom.testclasses; package org.deeplearning4j.nn.layers.custom.testclasses;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
public class CustomLayerImpl extends BaseLayer<CustomLayer> { public class CustomLayerImpl extends BaseLayer<CustomLayer> {
public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { public CustomLayerImpl(LayerConfiguration conf, DataType dataType) {
super(conf, dataType); super(conf, dataType);
} }

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -51,13 +52,14 @@ public class CustomOutputLayer extends BaseOutputLayer {
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
CustomOutputLayerImpl ret = new CustomOutputLayerImpl(conf, networkDataType); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType);
ret.setListeners(trainingListeners); ret.setListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
ret.setParamTable(paramTable); ret.setParamTable(paramTable);
ret.setLayerConfiguration(conf); ret.setLayerConfiguration(lconf);
return ret; return ret;
} }

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.layers.custom.testclasses; package org.deeplearning4j.nn.layers.custom.testclasses;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@ -28,7 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
public class CustomOutputLayerImpl extends BaseOutputLayer<CustomOutputLayer> { public class CustomOutputLayerImpl extends BaseOutputLayer<CustomOutputLayer> {
public CustomOutputLayerImpl(NeuralNetConfiguration conf, DataType dataType) { public CustomOutputLayerImpl(LayerConfiguration conf, DataType dataType) {
super(conf, dataType); super(conf, dataType);
} }

View File

@ -53,9 +53,9 @@ public class DenseTest extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(build).build(); NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(build).build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType());
assertEquals(1, layer.getParam("b").size(0)); assertEquals(1, layer.getParam("b").size(0));
} }
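
On the caller side the sequence stays the same as before the rename, only the accessor changes: size the parameter view from the first flattened layer configuration, allocate it, then instantiate. A minimal sketch assuming a plain DenseLayer with made-up sizes:

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class SingleLayerInstantiationSketch {
    // Build a one-layer configuration, size its parameter view, then instantiate the layer.
    static Layer buildDense() {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .layer(new DenseLayer.Builder().nIn(4).nOut(3).build())    // sizes are illustrative
                .build();
        long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
        INDArray params = Nd4j.create(1, numParams);                       // backing view for the layer's parameters
        return conf.getFlattenedLayerConfigurations().get(0)
                .instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType());
    }
}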

View File

@ -130,12 +130,12 @@ public class BatchNormalizationTest extends BaseDL4JTest {
BatchNormalization bN = b.build(); BatchNormalization bN = b.build();
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(bN).build(); NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(bN).build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = null; INDArray params = null;
if (numParams > 0) { if (numParams > 0) {
params = Nd4j.create(1, numParams); params = Nd4j.create(1, numParams);
} }
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params == null ? Nd4j.defaultFloatingPointType() : params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params == null ? Nd4j.defaultFloatingPointType() : params.dataType());
if (numParams > 0) { if (numParams > 0) {
layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
} }

View File

@ -123,7 +123,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
DataSet filtered = next.filterBy(new int[]{0, 1}); DataSet filtered = next.filterBy(new int[]{0, 1});
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
network.setEpochCount(i); network.setEpochCount(i);
network.getConfiguration().setEpochCount(i); network.getNetConfiguration().setEpochCount(i);
network.fit(filtered); network.fit(filtered);
} }
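
The epoch bookkeeping above applies to any manual fit loop: the network's own counter and the counter held by its NeuralNetConfiguration are advanced together. A short sketch; the network, data set, and epoch count are placeholders:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;

class EpochLoopSketch {
    // Keep the model's epoch counter and its configuration's counter in step during a manual loop.
    static void fitForEpochs(MultiLayerNetwork network, DataSet data, int numEpochs) {
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            network.setEpochCount(epoch);
            network.getNetConfiguration().setEpochCount(epoch);
            network.fit(data);
        }
    }
}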

View File

@ -68,10 +68,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
final GravesBidirectionalLSTM layer = final GravesBidirectionalLSTM layer =
(GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
//Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength]; //Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength];
@ -135,11 +135,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build(); .build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
GravesBidirectionalLSTM lstm = GravesBidirectionalLSTM lstm =
(GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf)));
//Set input, do a forward pass: //Set input, do a forward pass:
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input()); assertNotNull(lstm.input());
@ -207,10 +207,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build(); .build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
final GravesBidirectionalLSTM lstm = final GravesBidirectionalLSTM lstm =
(GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
final INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); final INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces());
@ -266,9 +266,9 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.build(); .build();
long numParams = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); long numParams = confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFlattenedLayerConfigurations().get(0)
.instantiate(confBidirectional, null, 0, params, true, params.dataType()); .instantiate(confBidirectional, null, 0, params, true, params.dataType());
@ -311,19 +311,19 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
.build(); .build();
long numParams = confForwards.getFirstLayer().initializer().numParams(confForwards); long numParams = confForwards.getFlattenedLayerConfigurations().get(0).initializer().numParams(confForwards);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
long numParamsBD = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); long numParamsBD = confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional);
INDArray paramsBD = Nd4j.create(1, numParamsBD); INDArray paramsBD = Nd4j.create(1, numParamsBD);
final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFlattenedLayerConfigurations().get(0)
.instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); .instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType());
final GravesLSTM forwardsLSTM = final GravesLSTM forwardsLSTM =
(GravesLSTM) confForwards.getFirstLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); (GravesLSTM) confForwards.getFlattenedLayerConfigurations().get(0).instantiate(confForwards, null, 0, params, true, params.dataType());
bidirectionalLSTM.setBackpropGradientsViewArray( bidirectionalLSTM.setBackpropGradientsViewArray(
Nd4j.create(1, confBidirectional.getFirstLayer().initializer().numParams(confBidirectional))); Nd4j.create(1, confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional)));
forwardsLSTM.setBackpropGradientsViewArray( forwardsLSTM.setBackpropGradientsViewArray(
Nd4j.create(1, confForwards.getFirstLayer().initializer().numParams(confForwards))); Nd4j.create(1, confForwards.getFlattenedLayerConfigurations().get(0).initializer().numParams(confForwards)));
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength): final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength):
@ -546,7 +546,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
net.init(); net.init();
assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).getNetConfiguration() assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).getNetConfiguration()
.getFirstLayer()).getGateActivationFn().toString()); .getFlattenedLayerConfigurations().get(0)).getGateActivationFn().toString());
INDArray in = Nd4j.rand(3, 2, 5); INDArray in = Nd4j.rand(3, 2, 5);
INDArray labels = Nd4j.rand(3, 2, 5); INDArray labels = Nd4j.rand(3, 2, 5);

View File

@ -63,9 +63,9 @@ public class GravesLSTMTest extends BaseDL4JTest {
.nOut(nHiddenUnits).activation(Activation.TANH).build()) .nOut(nHiddenUnits).activation(Activation.TANH).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
GravesLSTM layer = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); GravesLSTM layer = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
//Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength]; //Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength];
@ -109,10 +109,10 @@ public class GravesLSTMTest extends BaseDL4JTest {
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); GravesLSTM lstm = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf)));
//Set input, do a forward pass: //Set input, do a forward pass:
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input()); assertNotNull(lstm.input());
@ -160,9 +160,9 @@ public class GravesLSTMTest extends BaseDL4JTest {
.activation(Activation.TANH).build()) .activation(Activation.TANH).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); GravesLSTM lstm = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces());

View File

@ -73,7 +73,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
Map<String, INDArray> pt1 = net.getLayer(0).paramTable(); Map<String, INDArray> pt1 = net.getLayer(0).getParamTable();
assertNotNull(pt1); assertNotNull(pt1);
assertEquals(2, pt1.size()); assertEquals(2, pt1.size());
assertNotNull(pt1.get(ConvolutionParamInitializer.WEIGHT_KEY)); assertNotNull(pt1.get(ConvolutionParamInitializer.WEIGHT_KEY));

View File

@ -71,7 +71,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
Map<String, INDArray> pt1 = net.getLayer(0).paramTable(); Map<String, INDArray> pt1 = net.getLayer(0).getParamTable();
assertNotNull(pt1); assertNotNull(pt1);
assertEquals(2, pt1.size()); assertEquals(2, pt1.size());
assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY)); assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY));

View File

@ -104,7 +104,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
//Check params: //Check params:
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.params(), netSD.params());
assertEquals(netStandard.paramTable(), netSD.paramTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
INDArray in = Nd4j.rand(minibatch, nIn); INDArray in = Nd4j.rand(minibatch, nIn);
INDArray l = TestUtils.randomOneHot(minibatch, nOut, 12345); INDArray l = TestUtils.randomOneHot(minibatch, nOut, 12345);
@ -159,7 +159,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
netSD.fit(ds); netSD.fit(ds);
netStandard.fit(ds); netStandard.fit(ds);
assertEquals(netStandard.paramTable(), netSD.paramTable()); assertEquals(netStandard.getParamTable(), netSD.getParamTable());
assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.params(), netSD.params());
assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients()); assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients());
} }
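
Parameter-table comparisons such as the two above all have the same shape: fetch getParamTable() from both models and assert entry-wise equality. A small helper sketch (the helper name is an assumption; the JUnit Jupiter assertions match what the tests already use):

import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.jupiter.api.Assertions.assertEquals;

class ParamTableAssertSketch {
    // Assert that two models expose identical parameter tables, key by key.
    static void assertSameParams(Map<String, INDArray> expected, Map<String, INDArray> actual) {
        assertEquals(expected.keySet(), actual.keySet());
        for (Map.Entry<String, INDArray> e : expected.entrySet()) {
            assertEquals(e.getValue(), actual.get(e.getKey()), "parameter mismatch: " + e.getKey());
        }
    }
}
// typical call: assertSameParams(netStandard.getParamTable(), netSD.getParamTable());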

View File

@ -63,7 +63,7 @@ public class TestVAE extends BaseDL4JTest {
.build()) .build())
.build(); .build();
LayerConfiguration c = mlc.getFirstLayer(); LayerConfiguration c = mlc.getFlattenedLayerConfigurations().get(0);
org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae =
(VariationalAutoencoder) c; (VariationalAutoencoder) c;
@ -78,7 +78,7 @@ public class TestVAE extends BaseDL4JTest {
System.out.println("Exp num params: " + expNumParams); System.out.println("Exp num params: " + expNumParams);
assertEquals(expNumParams, net.getLayer(0).params().length()); assertEquals(expNumParams, net.getLayer(0).params().length());
Map<String, INDArray> paramTable = net.getLayer(0).paramTable(); Map<String, INDArray> paramTable = net.getLayer(0).getParamTable();
int count = 0; int count = 0;
for (INDArray arr : paramTable.values()) { for (INDArray arr : paramTable.values()) {
count += arr.length(); count += arr.length();
@ -135,7 +135,7 @@ public class TestVAE extends BaseDL4JTest {
net.init(); net.init();
net.initGradientsView(); //TODO this should happen automatically net.initGradientsView(); //TODO this should happen automatically
Map<String, INDArray> paramTable = net.getLayer(0).paramTable(); Map<String, INDArray> paramTable = net.getLayer(0).getParamTable();
Map<String, INDArray> gradTable = Map<String, INDArray> gradTable =
((org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0)) ((org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0))
.getGradientViews(); .getGradientViews();
@ -175,7 +175,7 @@ public class TestVAE extends BaseDL4JTest {
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder layer = org.deeplearning4j.nn.layers.variational.VariationalAutoencoder layer =
(org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0); (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);
Map<String, INDArray> layerParams = layer.paramTable(); Map<String, INDArray> layerParams = layer.getParamTable();
Map<String, INDArray> layerGradViews = layer.getGradientViews(); Map<String, INDArray> layerGradViews = layer.getGradientViews();
layer.setInput(Nd4j.rand(3, 10), LayerWorkspaceMgr.noWorkspaces()); layer.setInput(Nd4j.rand(3, 10), LayerWorkspaceMgr.noWorkspaces());
@ -239,7 +239,7 @@ public class TestVAE extends BaseDL4JTest {
net.pretrainLayer(0, input); net.pretrainLayer(0, input);
//Get a snapshot of the pretrain params after fitting: //Get a snapshot of the pretrain params after fitting:
Map<String, INDArray> layerParams = layer.paramTable(); Map<String, INDArray> layerParams = layer.getParamTable();
Map<String, INDArray> pretrainParamsBefore = new HashMap<>(); Map<String, INDArray> pretrainParamsBefore = new HashMap<>();
for (String s : layerParams.keySet()) { for (String s : layerParams.keySet()) {
if (layer.isPretrainParam(s)) { if (layer.isPretrainParam(s)) {
@ -255,7 +255,7 @@ public class TestVAE extends BaseDL4JTest {
net.fit(features, labels); net.fit(features, labels);
} }
Map<String, INDArray> layerParamsAfter = layer.paramTable(); Map<String, INDArray> layerParamsAfter = layer.getParamTable();
for (String s : pretrainParamsBefore.keySet()) { for (String s : pretrainParamsBefore.keySet()) {
INDArray before = pretrainParamsBefore.get(s); INDArray before = pretrainParamsBefore.get(s);
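
The pretrain check above rests on a small idiom: copy the relevant entries of getParamTable() with dup() before further training, so later in-place updates cannot touch the copy, then compare against the live table afterwards. A sketch of just that idiom, typed against the variational autoencoder layer used in the test; the class and method names are illustrative:

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.api.ndarray.INDArray;

class PretrainParamSnapshotSketch {
    // Detach a copy of the pretrain-specific parameters; dup() breaks aliasing with the live view.
    static Map<String, INDArray> snapshotPretrainParams(VariationalAutoencoder layer) {
        Map<String, INDArray> copy = new HashMap<>();
        for (Map.Entry<String, INDArray> e : layer.getParamTable().entrySet()) {
            if (layer.isPretrainParam(e.getKey())) {
                copy.put(e.getKey(), e.getValue().dup());
            }
        }
        return copy;
    }
}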

View File

@ -104,13 +104,13 @@ public class WorkspaceTests extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf.clone()); MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
net.init(); net.init();
net.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.ENABLED); net.getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.ENABLED);
net.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.ENABLED); net.getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.ENABLED);
MultiLayerNetwork net2 = new MultiLayerNetwork(conf.clone()); MultiLayerNetwork net2 = new MultiLayerNetwork(conf.clone());
net2.init(); net2.init();
net2.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); net2.getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE);
net2.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); net2.getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE);
INDArray in = Nd4j.rand(1, 2, 5, 5); INDArray in = Nd4j.rand(1, 2, 5, 5);
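
Workspace modes are now toggled through getNetConfiguration() after init(), as the two networks above show. A sketch that folds both variants into one helper; the helper name is an assumption:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

class WorkspaceModeSketch {
    // Build a network from a cloned configuration and switch both workspace modes at once.
    static MultiLayerNetwork initWithWorkspaces(NeuralNetConfiguration conf, boolean enabled) {
        MultiLayerNetwork net = new MultiLayerNetwork(conf.clone());
        net.init();
        WorkspaceMode mode = enabled ? WorkspaceMode.ENABLED : WorkspaceMode.NONE;
        net.getNetConfiguration().setInferenceWorkspaceMode(mode);
        net.getNetConfiguration().setTrainingWorkspaceMode(mode);
        return net;
    }
}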

View File

@ -817,15 +817,15 @@ public class MultiLayerTest extends BaseDL4JTest {
DataSetIterator iter = new IrisDataSetIterator(50, 150); DataSetIterator iter = new IrisDataSetIterator(50, 150);
assertEquals(0, network.getConfiguration().getIterationCount()); assertEquals(0, network.getNetConfiguration().getIterationCount());
network.fit(iter); network.fit(iter);
assertEquals(3, network.getConfiguration().getIterationCount()); assertEquals(3, network.getNetConfiguration().getIterationCount());
iter.reset(); iter.reset();
network.fit(iter); network.fit(iter);
assertEquals(6, network.getConfiguration().getIterationCount()); assertEquals(6, network.getNetConfiguration().getIterationCount());
iter.reset(); iter.reset();
network.fit(iter.next()); network.fit(iter.next());
assertEquals(7, network.getConfiguration().getIterationCount()); assertEquals(7, network.getNetConfiguration().getIterationCount());
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(network, baos, true); ModelSerializer.writeModel(network, baos, true);
@ -833,7 +833,7 @@ public class MultiLayerTest extends BaseDL4JTest {
ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); ByteArrayInputStream bais = new ByteArrayInputStream(asBytes);
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(7, net.getConfiguration().getIterationCount()); assertEquals(7, net.getNetConfiguration().getIterationCount());
} }
@ -1072,20 +1072,20 @@ public class MultiLayerTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
assertEquals(0, net.getConfiguration().getEpochCount()); assertEquals(0, net.getNetConfiguration().getEpochCount());
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
assertEquals(i, net.getConfiguration().getEpochCount()); assertEquals(i, net.getNetConfiguration().getEpochCount());
net.fit(iter); net.fit(iter);
assertEquals(i + 1, net.getConfiguration().getEpochCount()); assertEquals(i + 1, net.getNetConfiguration().getEpochCount());
} }
assertEquals(4, net.getConfiguration().getEpochCount()); assertEquals(4, net.getNetConfiguration().getEpochCount());
MultiLayerNetwork restored = TestUtils.testModelSerialization(net); MultiLayerNetwork restored = TestUtils.testModelSerialization(net);
assertEquals(4, restored.getConfiguration().getEpochCount()); assertEquals(4, restored.getNetConfiguration().getEpochCount());
} }
@Test @Test
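
The serialization assertions above reduce to one check: the iteration counter on getNetConfiguration() must survive a ModelSerializer round trip. A sketch of that round trip; the method name is an assumption, the stream plumbing mirrors the test:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import static org.junit.jupiter.api.Assertions.assertEquals;

class IterationCountRoundTripSketch {
    // Serialize a fitted network and confirm the restored copy reports the same iteration count.
    static void assertIterationCountSurvives(MultiLayerNetwork network) throws IOException {
        int before = network.getNetConfiguration().getIterationCount();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ModelSerializer.writeModel(network, baos, true);
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(
                new ByteArrayInputStream(baos.toByteArray()), true);
        assertEquals(before, restored.getNetConfiguration().getIterationCount());
    }
}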

View File

@ -86,7 +86,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
Layer layer = network.getLayer(0); Layer layer = network.getLayer(0);
assertTrue(layer instanceof GravesLSTM); assertTrue(layer instanceof GravesLSTM);
Map<String, INDArray> paramTable = layer.paramTable(); Map<String, INDArray> paramTable = layer.getParamTable();
assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases
INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY);
@ -131,7 +131,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
Layer layer = network.getLayer(i); Layer layer = network.getLayer(i);
assertTrue(layer instanceof GravesLSTM); assertTrue(layer instanceof GravesLSTM);
Map<String, INDArray> paramTable = layer.paramTable(); Map<String, INDArray> paramTable = layer.getParamTable();
assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases
int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]); int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]);
@ -458,9 +458,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
mlnTBPTT.clearTbpttState = false; mlnTBPTT.clearTbpttState = false;
assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getNetConfiguration().getBackpropType());
assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttFwdLength()); assertEquals(timeSeriesLength, mlnTBPTT.getNetConfiguration().getTbpttFwdLength());
assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttBackLength()); assertEquals(timeSeriesLength, mlnTBPTT.getNetConfiguration().getTbpttBackLength());
INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);

View File

@ -124,8 +124,8 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
net2GradUpd.getUpdater().getStateViewArray()); net2GradUpd.getUpdater().getStateViewArray());
//Remove the next 2 lines: fails - as net 1 is 1 iteration ahead //Remove the next 2 lines: fails - as net 1 is 1 iteration ahead
net1GradCalc.getConfiguration().setIterationCount(0); net1GradCalc.getNetConfiguration().setIterationCount(0);
net2GradUpd.getConfiguration().setIterationCount(0); net2GradUpd.getNetConfiguration().setIterationCount(0);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
net1GradCalc.fit(f, l); net1GradCalc.fit(f, l);

View File

@ -127,7 +127,7 @@ public class TestFrozenLayers extends BaseDL4JTest {
} }
Map<String,INDArray> paramsBefore = new LinkedHashMap<>(); Map<String,INDArray> paramsBefore = new LinkedHashMap<>();
for(Map.Entry<String,INDArray> entry : transfer.paramTable().entrySet()){ for(Map.Entry<String,INDArray> entry : transfer.getParamTable().entrySet()){
paramsBefore.put(entry.getKey(), entry.getValue().dup()); paramsBefore.put(entry.getKey(), entry.getValue().dup());
} }
@ -137,7 +137,7 @@ public class TestFrozenLayers extends BaseDL4JTest {
transfer.fit(new INDArray[]{f},new INDArray[]{l}); transfer.fit(new INDArray[]{f},new INDArray[]{l});
} }
for(Map.Entry<String,INDArray> entry : transfer.paramTable().entrySet()){ for(Map.Entry<String,INDArray> entry : transfer.getParamTable().entrySet()){
String s = msg + " - " + entry.getKey(); String s = msg + " - " + entry.getKey();
if(entry.getKey().startsWith("5_")){ if(entry.getKey().startsWith("5_")){
//Non-frozen layer //Non-frozen layer

View File

@ -70,9 +70,9 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest {
assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer);
assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer);
assertTrue(withFrozen.getConfiguration().getConf(0) assertTrue(withFrozen.getNetConfiguration().getConf(0)
.getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer);
assertTrue(withFrozen.getConfiguration().getConf(1) assertTrue(withFrozen.getNetConfiguration().getConf(1)
.getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer);
MultiLayerNetwork restored = TestUtils.testModelSerialization(withFrozen); MultiLayerNetwork restored = TestUtils.testModelSerialization(withFrozen);
@ -120,8 +120,8 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest {
assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer);
Map<String, GraphVertex> m = withFrozen.getComputationGraphConfiguration().getVertices(); Map<String, GraphVertex> m = withFrozen.getComputationGraphConfiguration().getVertices();
LayerConfiguration l0 = ((LayerVertex) m.get("0")).getNetConfiguration().getFirstLayer(); LayerConfiguration l0 = ((LayerVertex) m.get("0")).getLayerConfiguration();
LayerConfiguration l1 = ((LayerVertex) m.get("1")).getNetConfiguration().getFirstLayer(); LayerConfiguration l1 = ((LayerVertex) m.get("1")).getLayerConfiguration();
assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer);
assertTrue(l1 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); assertTrue(l1 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer);
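
LayerVertex now exposes its layer configuration directly via getLayerConfiguration(), so the frozen-layer checks above no longer detour through the vertex's net configuration. A compact sketch of the same check as a helper (the helper name and vertex key are illustrative):

import java.util.Map;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;

class FrozenVertexCheckSketch {
    // True if the named vertex wraps a frozen layer configuration.
    static boolean vertexLayerIsFrozen(ComputationGraph graph, String vertexName) {
        Map<String, GraphVertex> vertices = graph.getComputationGraphConfiguration().getVertices();
        LayerConfiguration lc = ((LayerVertex) vertices.get(vertexName)).getLayerConfiguration();
        return lc instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
    }
}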

View File

@ -605,13 +605,13 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
cg2.output(arr); cg2.output(arr);
Map<String,INDArray> m = new HashMap<>(cg.paramTable()); Map<String,INDArray> m = new HashMap<>(cg.getParamTable());
m.put("newOut_W", m.remove("out_W")); m.put("newOut_W", m.remove("out_W"));
m.put("newOut_b", m.remove("out_b")); m.put("newOut_b", m.remove("out_b"));
cg2.setParamTable(m); cg2.setParamTable(m);
Map<String,INDArray> p1 = cg.paramTable(); Map<String,INDArray> p1 = cg.getParamTable();
Map<String,INDArray> p2 = cg2.paramTable(); Map<String,INDArray> p2 = cg2.getParamTable();
for(String s : p1.keySet()){ for(String s : p1.keySet()){
INDArray i1 = p1.get(s); INDArray i1 = p1.get(s);
INDArray i2 = p2.get(s.replaceAll("out", "newOut")); INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
@ -651,13 +651,13 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
cg2.output(arr); cg2.output(arr);
Map<String,INDArray> m = new HashMap<>(cg.paramTable()); Map<String,INDArray> m = new HashMap<>(cg.getParamTable());
m.put("newOut_W", m.remove("out_W")); m.put("newOut_W", m.remove("out_W"));
m.put("newOut_b", m.remove("out_b")); m.put("newOut_b", m.remove("out_b"));
cg2.setParamTable(m); cg2.setParamTable(m);
Map<String,INDArray> p1 = cg.paramTable(); Map<String,INDArray> p1 = cg.getParamTable();
Map<String,INDArray> p2 = cg2.paramTable(); Map<String,INDArray> p2 = cg2.getParamTable();
for(String s : p1.keySet()){ for(String s : p1.keySet()){
INDArray i1 = p1.get(s); INDArray i1 = p1.get(s);
INDArray i2 = p2.get(s.replaceAll("out", "newOut")); INDArray i2 = p2.get(s.replaceAll("out", "newOut"));

View File

@ -112,8 +112,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals(expectedModel.params(), modelNow.params()); assertEquals(expectedModel.params(), modelNow.params());
//Check json //Check json
NeuralNetConfiguration expectedConf = expectedModel.getConfiguration(); NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration();
assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson()); assertEquals(expectedConf.toJson(), modelNow.getNetConfiguration().toJson());
//Check params after fit //Check params after fit
modelNow.fit(randomData); modelNow.fit(randomData);
@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
//Will fail - expected because of dist and weight init changes //Will fail - expected because of dist and weight init changes
//assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
BaseLayer bl0 = ((BaseLayer) modelNow.getConfiguration().getConf(0).getLayer()); BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer());
BaseLayer bl1 = ((BaseLayer) modelNow.getConfiguration().getConf(1).getLayer()); BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer());
BaseLayer bl3 = ((BaseLayer) modelNow.getConfiguration().getConf(3).getLayer()); BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer());
assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class);
try { try {
assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()),
@ -357,18 +357,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build(); .setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build();
//modelNow should have the same architecture as modelExpectedArch //modelNow should have the same architecture as modelExpectedArch
assertEquals(modelExpectedArch.getConfiguration().getConf(0).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(0).toJson(),
modelNow.getConfiguration().getConf(0).toJson()); modelNow.getNetConfiguration().getConf(0).toJson());
//some learning related info in the subsampling layer will not be overwritten //some learning related info in the subsampling layer will not be overwritten
//assertTrue(modelExpectedArch.getConfiguration().getConf(1).toJson().equals(modelNow.getConfiguration().getConf(1).toJson())); //assertTrue(modelExpectedArch.getConfiguration().getConf(1).toJson().equals(modelNow.getConfiguration().getConf(1).toJson()));
assertEquals(modelExpectedArch.getConfiguration().getConf(2).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(2).toJson(),
modelNow.getConfiguration().getConf(2).toJson()); modelNow.getNetConfiguration().getConf(2).toJson());
assertEquals(modelExpectedArch.getConfiguration().getConf(3).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(3).toJson(),
modelNow.getConfiguration().getConf(3).toJson()); modelNow.getNetConfiguration().getConf(3).toJson());
assertEquals(modelExpectedArch.getConfiguration().getConf(4).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(4).toJson(),
modelNow.getConfiguration().getConf(4).toJson()); modelNow.getNetConfiguration().getConf(4).toJson());
assertEquals(modelExpectedArch.getConfiguration().getConf(5).toJson(), assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(),
modelNow.getConfiguration().getConf(5).toJson()); modelNow.getNetConfiguration().getConf(5).toJson());
assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape());
@ -530,7 +530,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn());
assertEquals(0.2, TestUtils.getL2(l1), 1e-6); assertEquals(0.2, TestUtils.getL2(l1), 1e-6);
assertEquals(BackpropType.TruncatedBPTT, net2.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType());
} }
@Test @Test
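
The layer-by-layer architecture comparison above goes through getNetConfiguration().getConf(i) and the JSON form of each layer. A small helper sketch (the helper name and index list are assumptions):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import static org.junit.jupiter.api.Assertions.assertEquals;

class LayerJsonCompareSketch {
    // Compare selected layer configurations of two networks by their serialized JSON.
    static void assertSameLayerJson(MultiLayerNetwork expected, MultiLayerNetwork actual, int... layerIndices) {
        for (int i : layerIndices) {
            assertEquals(expected.getNetConfiguration().getConf(i).toJson(),
                         actual.getNetConfiguration().getConf(i).toJson());
        }
    }
}
// typical call: assertSameLayerJson(modelExpectedArch, modelNow, 0, 2, 3, 4, 5);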

View File

@ -52,9 +52,9 @@ public class TestGradientNormalization extends BaseDL4JTest {
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build()) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build())
.build(); .build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5);
layer.setBackpropGradientsViewArray(gradArray); layer.setBackpropGradientsViewArray(gradArray);
INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)),
@ -98,9 +98,9 @@ public class TestGradientNormalization extends BaseDL4JTest {
.gradientNormalization(GradientNormalization.RenormalizeL2PerParamType).build()) .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType).build())
.build(); .build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape()));
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
INDArray weightGrad = Nd4j.rand(10, 20); INDArray weightGrad = Nd4j.rand(10, 20);
@ -131,9 +131,9 @@ public class TestGradientNormalization extends BaseDL4JTest {
.gradientNormalizationThreshold(threshold).build()) .gradientNormalizationThreshold(threshold).build())
.build(); .build();
long numParams = conf.getFirstLayer().initializer().numParams(conf); long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5);
layer.setBackpropGradientsViewArray(gradArray); layer.setBackpropGradientsViewArray(gradArray);
INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)),
@ -187,9 +187,9 @@ public class TestGradientNormalization extends BaseDL4JTest {
.gradientNormalizationThreshold(threshold).build()) .gradientNormalizationThreshold(threshold).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
INDArray gradArray = Nd4j.rand(1, 220).muli(t == 0 ? 0.05 : 10).subi(t == 0 ? 0 : 5); INDArray gradArray = Nd4j.rand(1, 220).muli(t == 0 ? 0.05 : 10).subi(t == 0 ? 0 : 5);
layer.setBackpropGradientsViewArray(gradArray); layer.setBackpropGradientsViewArray(gradArray);
INDArray weightGrad = INDArray weightGrad =
@ -242,9 +242,9 @@ public class TestGradientNormalization extends BaseDL4JTest {
.gradientNormalizationThreshold(threshold).build()) .gradientNormalizationThreshold(threshold).build())
.build(); .build();
val numParams = conf.getFirstLayer().initializer().numParams(conf); val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape()));
Updater updater = UpdaterCreator.getUpdater(layer); Updater updater = UpdaterCreator.getUpdater(layer);
INDArray weightGrad = Nd4j.rand(10, 20).muli(0.05); INDArray weightGrad = Nd4j.rand(10, 20).muli(0.05);

View File

@ -20,6 +20,7 @@
package org.deeplearning4j.optimize.solver; package org.deeplearning4j.optimize.solver;
import lombok.NonNull;
import lombok.val; import lombok.val;
import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -44,6 +45,7 @@ import org.deeplearning4j.optimize.solvers.LineGradientDescent;
import org.deeplearning4j.optimize.solvers.StochasticGradientDescent; import org.deeplearning4j.optimize.solvers.StochasticGradientDescent;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction; import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -52,7 +54,9 @@ import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.indexing.conditions.Condition;
@ -317,6 +321,90 @@ public class TestOptimizers extends BaseDL4JTest {
} }
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return null;
}
/**
* This method fits model with a given DataSet
*
* @param dataSet
*/
@Override
public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
}
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet
*/
@Override
public void fit(MultiDataSet dataSet) {
}
/**
* This method fits model with a given DataSetIterator
*
* @param iterator
*/
@Override
public void fit(DataSetIterator iterator) {
}
/**
* This method fits model with a given MultiDataSetIterator
*
* @param iterator
*/
@Override
public void fit(MultiDataSetIterator iterator) {
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* @param netConfiguration
*/
@Override
public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) {
}
@Override @Override
public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
// Gradients: d(x^2)/dx = 2x // Gradients: d(x^2)/dx = 2x
@ -464,6 +552,90 @@ public class TestOptimizers extends BaseDL4JTest {
} }
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return null;
}
/**
* This method fits model with a given DataSet
*
* @param dataSet
*/
@Override
public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
}
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet
*/
@Override
public void fit(MultiDataSet dataSet) {
}
/**
* This method fits model with a given DataSetIterator
*
* @param iterator
*/
@Override
public void fit(DataSetIterator iterator) {
}
/**
* This method fits model with a given MultiDataSetIterator
*
* @param iterator
*/
@Override
public void fit(MultiDataSetIterator iterator) {
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* @param netConfiguration
*/
@Override
public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) {
}
@Override @Override
public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
//Gradient decomposes due to sum, so: //Gradient decomposes due to sum, so:
@ -649,6 +821,90 @@ public class TestOptimizers extends BaseDL4JTest {
return dist.sample(new int[] {1, nDimensions}); return dist.sample(new int[] {1, nDimensions});
} }
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return null;
}
/**
* This method fits model with a given DataSet
*
* @param dataSet
*/
@Override
public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
}
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet
*/
@Override
public void fit(MultiDataSet dataSet) {
}
/**
* This method fits model with a given DataSetIterator
*
* @param iterator
*/
@Override
public void fit(DataSetIterator iterator) {
}
/**
* This method fits model with a given MultiDataSetIterator
*
* @param iterator
*/
@Override
public void fit(MultiDataSetIterator iterator) {
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator,
T... evaluations) {
return null;
}
/**
* @param netConfiguration
*/
@Override
public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) {
}
@Override @Override
public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
val nDims = parameters.length(); val nDims = parameters.length();
@ -912,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { public void setLayerConfiguration(LayerConfiguration layerConfiguration) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -934,13 +1190,13 @@ public class TestOptimizers extends BaseDL4JTest {
} }
@Override @Override
public Map<String, INDArray> paramTable() { public Map<String, INDArray> getParamTable() {
return Collections.singletonMap("W", getParam("W")); return Collections.singletonMap("W", getParam("W"));
} }
@Override @Override
public Map<String, INDArray> paramTable(boolean backpropParamsOnly) { public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) {
return paramTable(); return getParamTable();
} }
@Override @Override
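
Test stubs that previously overrode paramTable()/paramTable(boolean) and setLayerConfiguration(NeuralNetConfiguration) now override the renamed pair and the LayerConfiguration-typed setter, as above. The minimal shape of those overrides in any such stub, shown in isolation:

// Members of a synthetic single-parameter model stub; "W" mirrors the toy parameter used in the test.
@Override
public void setLayerConfiguration(LayerConfiguration layerConfiguration) {
    throw new UnsupportedOperationException();           // the stub carries no per-layer configuration
}

@Override
public Map<String, INDArray> getParamTable() {
    return Collections.singletonMap("W", getParam("W"));
}

@Override
public Map<String, INDArray> getParamTable(boolean backpropParamsOnly) {
    return getParamTable();                              // no backprop/pretrain distinction for this stub
}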

View File

@ -65,7 +65,7 @@ public class RegressionTest050 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -99,7 +99,7 @@ public class RegressionTest050 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -138,7 +138,7 @@ public class RegressionTest050 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer();

View File

@ -67,7 +67,7 @@ public class RegressionTest060 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -101,7 +101,7 @@ public class RegressionTest060 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -144,7 +144,7 @@ public class RegressionTest060 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer();
@ -190,7 +190,7 @@ public class RegressionTest060 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer();

View File

@ -68,7 +68,7 @@ public class RegressionTest071 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -102,7 +102,7 @@ public class RegressionTest071 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -145,7 +145,7 @@ public class RegressionTest071 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer();
@ -191,7 +191,7 @@ public class RegressionTest071 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer();

View File

@ -67,7 +67,7 @@ public class RegressionTest080 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -106,7 +106,7 @@ public class RegressionTest080 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(2, conf.getNetConfigurations().size()); assertEquals(2, conf.getNetConfigurations().size());
DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer();
@ -155,7 +155,7 @@ public class RegressionTest080 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer();
@ -206,7 +206,7 @@ public class RegressionTest080 extends BaseDL4JTest {
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true);
NeuralNetConfiguration conf = net.getConfiguration(); NeuralNetConfiguration conf = net.getNetConfiguration();
assertEquals(3, conf.getNetConfigurations().size()); assertEquals(3, conf.getNetConfigurations().size());
GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer();

View File

@ -107,9 +107,9 @@ public class RegressionTest100a extends BaseDL4JTest {
assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new RmsProp(0.1), l0.getIUpdater()); assertEquals(new RmsProp(0.1), l0.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
assertEquals(50, net.getConfiguration().getTbpttFwdLength()); assertEquals(50, net.getNetConfiguration().getTbpttFwdLength());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100a/GravesLSTMCharModelingExample_Output_100a.bin"); File f2 = Resources.asFile("regression_testing/100a/GravesLSTMCharModelingExample_Output_100a.bin");
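For orientation, a minimal sketch of the renamed accessor used throughout these regression assertions, assuming only the calls visible in this hunk (net is the restored MultiLayerNetwork):
    NeuralNetConfiguration nc = net.getNetConfiguration();      // formerly net.getConfiguration()
    assertEquals(BackpropType.TruncatedBPTT, nc.getBackpropType());
    assertEquals(50, nc.getTbpttBackLength());
    assertEquals(50, nc.getTbpttFwdLength());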

View File

@ -108,7 +108,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
List<INDArray> activations = net.feedForward(in); List<INDArray> activations = net.feedForward(in);
assertEquals(dt, net.getConfiguration().getDataType()); assertEquals(dt, net.getNetConfiguration().getDataType());
assertEquals(dt, net.params().dataType()); assertEquals(dt, net.params().dataType());
assertEquals( outExp, outAct, dtype); assertEquals( outExp, outAct, dtype);
} }
@ -142,9 +142,9 @@ public class RegressionTest100b3 extends BaseDL4JTest {
assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0));
assertEquals(new Adam(0.005), l0.getIUpdater()); assertEquals(new Adam(0.005), l0.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
assertEquals(50, net.getConfiguration().getTbpttFwdLength()); assertEquals(50, net.getNetConfiguration().getTbpttFwdLength());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b3/GravesLSTMCharModelingExample_Output_100b3.bin"); File f2 = Resources.asFile("regression_testing/100b3/GravesLSTMCharModelingExample_Output_100b3.bin");

View File

@ -125,7 +125,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
INDArray outAct = net.output(in); INDArray outAct = net.output(in);
assertEquals(dtype, outAct.dataType()); assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getConfiguration().getDataType()); assertEquals(dtype, net.getNetConfiguration().getDataType());
assertEquals(dtype, net.params().dataType()); assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01); boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct); assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct);
@ -160,9 +160,9 @@ public class RegressionTest100b4 extends BaseDL4JTest {
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
assertEquals(50, net.getConfiguration().getTbpttFwdLength()); assertEquals(50, net.getNetConfiguration().getTbpttFwdLength());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin"); File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin");

View File

@ -107,7 +107,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
INDArray outAct = net.output(in); INDArray outAct = net.output(in);
assertEquals(dtype, outAct.dataType()); assertEquals(dtype, outAct.dataType());
assertEquals(dtype, net.getConfiguration().getDataType()); assertEquals(dtype, net.getNetConfiguration().getDataType());
assertEquals(dtype, net.params().dataType()); assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01); boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct); assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct);
@ -142,9 +142,9 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
assertEquals(new Adam(0.005), l2.getIUpdater()); assertEquals(new Adam(0.005), l2.getIUpdater());
assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
assertEquals(50, net.getConfiguration().getTbpttBackLength()); assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
assertEquals(50, net.getConfiguration().getTbpttFwdLength()); assertEquals(50, net.getNetConfiguration().getTbpttFwdLength());
INDArray outExp; INDArray outExp;
File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin"); File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin");

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
@ -68,11 +69,13 @@ public class CustomLayer extends FeedForwardLayer {
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners, public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(0);
//The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class //The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class
// (i.e., a CustomLayerImpl instance) // (i.e., a CustomLayerImpl instance)
//For the most part, it's the same for each type of layer //For the most part, it's the same for each type of layer
CustomLayerImpl myCustomLayer = new CustomLayerImpl(conf, networkDataType); CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType);
myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, if any myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, if any
myCustomLayer.setIndex(layerIndex); //Integer index of the layer myCustomLayer.setIndex(layerIndex); //Integer index of the layer
@ -87,7 +90,7 @@ public class CustomLayer extends FeedForwardLayer {
// are in turn a view of the 'layerParamsView' array. // are in turn a view of the 'layerParamsView' array.
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
myCustomLayer.setParamTable(paramTable); myCustomLayer.setParamTable(paramTable);
myCustomLayer.setLayerConfiguration(conf); myCustomLayer.setLayerConfiguration(lconf);
return myCustomLayer; return myCustomLayer;
} }
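For readers following the custom-layer migration: a minimal sketch of the instantiate(...) contract after this change, using only the accessors visible in this diff (getFlattenedLayerConfigurations, setLayerConfiguration, the LayerConfiguration-based constructor). MyLayerImpl is an illustrative class name; note this file's hunk resolves index 0, while the Spark test variant later in the diff resolves layerIndex, which is what the sketch uses.
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> listeners,
            int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        // Resolve this layer's own configuration from the flattened network configuration
        LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
        MyLayerImpl impl = new MyLayerImpl(lconf, networkDataType); // implementations now take a LayerConfiguration
        impl.setListeners(listeners);
        impl.setIndex(layerIndex);
        impl.setParamsViewArray(layerParamsView);
        impl.setParamTable(initializer().init(this, layerParamsView, initializeParams));
        impl.setLayerConfiguration(lconf);
        return impl;
    }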

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.regressiontest.customlayer100a; package org.deeplearning4j.regressiontest.customlayer100a;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.BaseLayer;
@ -35,7 +36,7 @@ import org.nd4j.common.primitives.Pair;
public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic parameter here: the configuration class type
public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { public CustomLayerImpl(LayerConfiguration conf, DataType dataType) {
super(conf, dataType); super(conf, dataType);
} }
@ -56,7 +57,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
IActivation activation1 = layerConf().getActivationFn(); IActivation activation1 = layerConf().getActivationFn();
IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
//IActivation function instances modify the activation functions in-place //IActivation function instances modify the activation functions in-place
activation1.getActivation(firstHalf, training); activation1.getActivation(firstHalf, training);
@ -105,7 +106,7 @@ public class CustomLayerImpl extends BaseLayer<CustomLayer> { //Generic paramete
INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns));
IActivation activation1 = layerConf().getActivationFn(); IActivation activation1 = layerConf().getActivationFn();
IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction();
//IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz
activation1.backprop(firstHalf, epsilonFirstHalf); activation1.backprop(firstHalf, epsilonFirstHalf);

View File

@ -155,7 +155,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
ModelSerializer.writeModel(net, tempFile, true); ModelSerializer.writeModel(net, tempFile, true);
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.params(), network.params()); assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
@ -172,7 +172,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
try (InputStream inputStream = new FileInputStream(tempFile)) { try (InputStream inputStream = new FileInputStream(tempFile)) {
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
Assertions.assertNotNull(network); Assertions.assertNotNull(network);
assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.params(), network.params()); assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
} }

View File

@ -80,7 +80,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile);
assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.params(), network.params()); assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
} }
@ -124,7 +124,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis);
assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson());
assertEquals(net.params(), network.params()); assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
} }

View File

@ -24,7 +24,6 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
@ -57,7 +56,7 @@ public class KerasModelImportTest extends BaseDL4JTest {
@Test @Test
public void testNCHWNWHCChangeImport() { public void testNCHWNWHCChangeImport() {
MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5"); MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5");
List<LayerConfiguration> layerConfigs = model.getConfiguration().getFlattenedLayerConfigurations(); List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations();
ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0); ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0);
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getCnn2dDataFormat()); assertEquals(CNN2DFormat.NCHW,convolutionLayer.getCnn2dDataFormat());
SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1); SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1);

View File

@ -208,7 +208,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3)); assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
} }
} }

View File

@ -46,6 +46,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
public interface IModel { public interface IModel {
/**
* The parameter table of this model.
*
* @return a map of parameter names to parameter arrays
*/
Map<String, INDArray> getParamTable();
Map<String, INDArray> getParamTable(boolean backpropOnly);
/** /**
* This method returns updater state (if applicable), null otherwise * This method returns updater state (if applicable), null otherwise
* *
@ -273,6 +282,7 @@ public interface IModel {
* @param listeners new listeners * @param listeners new listeners
*/ */
void setListeners(TrainingListener... listeners); void setListeners(TrainingListener... listeners);
void setListeners(Collection<TrainingListener> listeners);
/** /**
* Add TrainingListeners to the model * Add TrainingListeners to the model
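Since IModel now exposes the parameter table directly, a minimal usage sketch assuming only the methods declared above (HashMap/Collections imports omitted; "model" stands for any IModel implementation, and ScoreIterationListener is purely an illustrative listener):
    Map<String, INDArray> params = model.getParamTable();
    Map<String, INDArray> snapshot = new HashMap<>();
    for (Map.Entry<String, INDArray> e : params.entrySet()) {
        snapshot.put(e.getKey(), e.getValue().dup()); // dup() so later training cannot mutate the copy
    }
    // Collection-based overload added alongside the existing varargs variant
    model.setListeners(Collections.singletonList(new ScoreIterationListener(10)));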

View File

@ -1126,6 +1126,17 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
return getFlattenedLayerConfigurations().get(index); return getFlattenedLayerConfigurations().get(index);
} }
/**
* Deprecated, do not use. Workaround for old tests; equivalent to
* getFlattenedLayerConfigurations().get(0).
* @return the first layer configuration
*/
@Deprecated
public LayerConfiguration getFirstLayer() {
log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
return getFlattenedLayerConfigurations().get(0);
}
public static abstract class NeuralNetConfigurationBuilder<C extends NeuralNetConfiguration, public static abstract class NeuralNetConfigurationBuilder<C extends NeuralNetConfiguration,
B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>> extends B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>> extends
NeuralNetBaseBuilderConfigurationBuilder<C, B> { NeuralNetBaseBuilderConfigurationBuilder<C, B> {
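In other words, the deprecated helper is only a shorthand; a short sketch of the equivalence (conf is any NeuralNetConfiguration):
    LayerConfiguration first = conf.getFirstLayer();                          // deprecated shortcut, logs a warning
    LayerConfiguration same  = conf.getFlattenedLayerConfigurations().get(0); // preferred replacement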

View File

@ -258,6 +258,9 @@ public abstract class LayerConfiguration implements TrainingConfig, Serializable
"Not supported: all layers with parameters should override this method"); "Not supported: all layers with parameters should override this method");
} }
@Getter
private IUpdater iUpdater;
@Override @Override
public void setDataType(DataType dataType) { public void setDataType(DataType dataType) {
//No-op for most layers //No-op for most layers

View File

@ -2443,6 +2443,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial
} }
} }
/**
* @param listeners
*/
@Override
public void setListeners(Collection<TrainingListener> listeners) {
setListeners(listeners.toArray(new TrainingListener[]{}));
}
/** /**
* @deprecated Use {@link #getListeners()} * @deprecated Use {@link #getListeners()}
*/ */
@ -4525,4 +4533,5 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial
public String toString() { public String toString() {
return getNetConfiguration().toString(); return getNetConfiguration().toString();
} }
} }

View File

@ -94,13 +94,13 @@ public class EarlyStoppingParallelTrainer<T extends IModel> implements IEarlySto
Collection<TrainingListener> listeners = ((MultiLayerNetwork) model).getListeners(); Collection<TrainingListener> listeners = ((MultiLayerNetwork) model).getListeners();
Collection<TrainingListener> newListeners = new LinkedList<>(listeners); Collection<TrainingListener> newListeners = new LinkedList<>(listeners);
newListeners.add(trainerListener); newListeners.add(trainerListener);
model.setListeners(newListeners); model.setListeners(newListeners.toArray(new TrainingListener[]{}));
} else if (model instanceof ComputationGraph) { } else if (model instanceof ComputationGraph) {
Collection<TrainingListener> listeners = ((ComputationGraph) model).getListeners(); Collection<TrainingListener> listeners = ((ComputationGraph) model).getListeners();
Collection<TrainingListener> newListeners = new LinkedList<>(listeners); Collection<TrainingListener> newListeners = new LinkedList<>(listeners);
newListeners.add(trainerListener); newListeners.add(trainerListener);
model.setListeners(newListeners); model.setListeners(newListeners.toArray(new TrainingListener[]{}));
} }
this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer) this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer)

View File

@ -204,7 +204,7 @@ public class InplaceParallelInference extends ParallelInference {
if (loadBalanceMode == LoadBalanceMode.FIFO) if (loadBalanceMode == LoadBalanceMode.FIFO)
queue.add(model); queue.add(model);
} else if (sourceModel instanceof MultiLayerNetwork) { } else if (sourceModel instanceof MultiLayerNetwork) {
val model = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getConfiguration().toJson())); val model = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getNetConfiguration().toJson()));
model.init(params, false); model.init(params, false);
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();

View File

@ -472,7 +472,7 @@ public class ParallelInference {
} else if (protoModel instanceof MultiLayerNetwork) { } else if (protoModel instanceof MultiLayerNetwork) {
if (!rootDevice) { if (!rootDevice) {
this.replicatedModel = new MultiLayerNetwork(NeuralNetConfiguration.fromJson( this.replicatedModel = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(
((MultiLayerNetwork) protoModel).getConfiguration().toJson())); ((MultiLayerNetwork) protoModel).getNetConfiguration().toJson()));
this.replicatedModel.init(); this.replicatedModel.init();
synchronized (locker) { synchronized (locker) {

View File

@ -957,10 +957,10 @@ public class ParallelWrapper implements AutoCloseable {
List<TrainingListener> modelListeners = null; List<TrainingListener> modelListeners = null;
if (model instanceof MultiLayerNetwork) { if (model instanceof MultiLayerNetwork) {
modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getListeners()); modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getListeners());
model.setListeners(Collections.emptyList()); model.setListeners(new TrainingListener[]{});
} else if (model instanceof ComputationGraph) { } else if (model instanceof ComputationGraph) {
modelListeners = new ArrayList<>(((ComputationGraph) model).getListeners()); modelListeners = new ArrayList<>(((ComputationGraph) model).getListeners());
model.setListeners(Collections.emptyList()); model.setListeners(new TrainingListener[]{});
} }
if (modelListeners != null && !modelListeners.isEmpty()) { if (modelListeners != null && !modelListeners.isEmpty()) {

View File

@ -278,7 +278,7 @@ public class DefaultTrainer extends Thread implements Trainer {
} }
configureListeners(uuid, oldListeners, replicatedListeners); configureListeners(uuid, oldListeners, replicatedListeners);
this.replicatedModel.setListeners(replicatedListeners); this.replicatedModel.setListeners(replicatedListeners.toArray(new TrainingListener[]{}));
} }
@Override @Override
@ -296,7 +296,7 @@ public class DefaultTrainer extends Thread implements Trainer {
if (originalModel instanceof MultiLayerNetwork) { if (originalModel instanceof MultiLayerNetwork) {
if (!onRootModel) { if (!onRootModel) {
NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson( NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(
((MultiLayerNetwork) originalModel).getConfiguration().toJson()); ((MultiLayerNetwork) originalModel).getNetConfiguration().toJson());
conf.setTrainingWorkspaceMode(workspaceMode); conf.setTrainingWorkspaceMode(workspaceMode);
this.replicatedModel = new MultiLayerNetwork(conf); this.replicatedModel = new MultiLayerNetwork(conf);
@ -323,7 +323,7 @@ public class DefaultTrainer extends Thread implements Trainer {
if (!((MultiLayerNetwork) replicatedModel).isInitCalled()) if (!((MultiLayerNetwork) replicatedModel).isInitCalled())
this.replicatedModel.init(); this.replicatedModel.init();
((MultiLayerNetwork) replicatedModel).getConfiguration() ((MultiLayerNetwork) replicatedModel).getNetConfiguration()
.setTrainingWorkspaceMode(workspaceMode); .setTrainingWorkspaceMode(workspaceMode);
} }
} else if (originalModel instanceof ComputationGraph) { } else if (originalModel instanceof ComputationGraph) {

View File

@ -122,7 +122,7 @@ public class SparkDl4jMultiLayer extends SparkListenable {
public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network, public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network,
TrainingMaster<?, ?> trainingMaster) { TrainingMaster<?, ?> trainingMaster) {
sc = javaSparkContext; sc = javaSparkContext;
this.conf = network.getConfiguration().clone(); this.conf = network.getNetConfiguration().clone();
this.network = network; this.network = network;
if (!network.isInitCalled()) if (!network.isInitCalled())
network.init(); network.init();
@ -315,8 +315,8 @@ public class SparkDl4jMultiLayer extends SparkListenable {
* @return the multi layer network that was fitDataSet * @return the multi layer network that was fitDataSet
*/ */
public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> rdd) { public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> rdd) {
int nLayers = network.getConfiguration().getFlattenedLayerConfigurations().size(); int nLayers = network.getNetConfiguration().getFlattenedLayerConfigurations().size();
FeedForwardLayer ffl = (FeedForwardLayer) network.getConfiguration().getFlattenedLayerConfigurations().get(nLayers - 1); FeedForwardLayer ffl = (FeedForwardLayer) network.getNetConfiguration().getFlattenedLayerConfigurations().get(nLayers - 1);
JavaRDD<DataSet> ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut()); JavaRDD<DataSet> ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut());
return fit(ds); return fit(ds);
} }
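A minimal sketch of the lookup that fitLabeledPoint now performs through the renamed accessor (variable names are illustrative; only calls shown in this hunk are assumed):
    List<LayerConfiguration> layers = network.getNetConfiguration().getFlattenedLayerConfigurations();
    FeedForwardLayer outputLayer = (FeedForwardLayer) layers.get(layers.size() - 1);
    long nOut = outputLayer.getNOut(); // label count passed to MLLibUtil.fromLabeledPoint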

View File

@ -275,7 +275,7 @@ public class ParameterAveragingTrainingMaster
@Override @Override
public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) { public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getConfiguration(), NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getNetConfiguration(),
network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray());
if (collectTrainingStats) if (collectTrainingStats)
@ -727,7 +727,7 @@ public class ParameterAveragingTrainingMaster
if (params != null) { if (params != null) {
//Params may be null for edge case (empty RDD) //Params may be null for edge case (empty RDD)
if (network != null) { if (network != null) {
NeuralNetConfiguration conf = network.getNetwork().getConfiguration(); NeuralNetConfiguration conf = network.getNetwork().getNetConfiguration();
int numUpdates = averagingFrequency; int numUpdates = averagingFrequency;
conf.setIterationCount(conf.getIterationCount() + numUpdates); conf.setIterationCount(conf.getIterationCount() + numUpdates);
} else { } else {

View File

@ -172,9 +172,9 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker<Paramet
list.add(l); //Don't need to clone listeners: not from broadcast, so deserialization handles list.add(l); //Don't need to clone listeners: not from broadcast, so deserialization handles
} }
if (m instanceof MultiLayerNetwork) if (m instanceof MultiLayerNetwork)
m.setListeners(list); m.setListeners(list.toArray(new TrainingListener[]{}));
else else
m.setListeners(list); m.setListeners(list.toArray(new TrainingListener[]{}));
} }
} }

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
@ -53,13 +54,14 @@ public class CustomLayer extends FeedForwardLayer {
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) { boolean initializeParams, DataType networkDataType) {
CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType);
ret.setListeners(trainingListeners); ret.setListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams); Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
ret.setParamTable(paramTable); ret.setParamTable(paramTable);
ret.setLayerConfiguration(conf); ret.setLayerConfiguration(lconf);
return ret; return ret;
} }

View File

@ -21,11 +21,12 @@
package org.deeplearning4j.spark.impl.customlayer.layer; package org.deeplearning4j.spark.impl.customlayer.layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
public class CustomLayerImpl extends BaseLayer<CustomLayer> { public class CustomLayerImpl extends BaseLayer<CustomLayer> {
public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { public CustomLayerImpl(LayerConfiguration conf, DataType dataType) {
super(conf, dataType); super(conf, dataType);
} }

View File

@ -154,7 +154,7 @@ public class TestFrozenLayers extends BaseSparkTest {
ComputationGraph withFrozen = new TransferLearning.GraphBuilder(origModel).fineTuneConfiguration(finetune) ComputationGraph withFrozen = new TransferLearning.GraphBuilder(origModel).fineTuneConfiguration(finetune)
.setFeatureExtractor("1").build(); .setFeatureExtractor("1").build();
Map<String, INDArray> m = withFrozen.paramTable(); Map<String, INDArray> m = withFrozen.getParamTable();
Map<String, INDArray> pCopy = new HashMap<>(); Map<String, INDArray> pCopy = new HashMap<>();
for (Map.Entry<String, INDArray> entry : m.entrySet()) { for (Map.Entry<String, INDArray> entry : m.entrySet()) {
pCopy.put(entry.getKey(), entry.getValue().dup()); pCopy.put(entry.getKey(), entry.getValue().dup());
@ -190,7 +190,7 @@ public class TestFrozenLayers extends BaseSparkTest {
ComputationGraph fitted = sNet.getNetwork(); ComputationGraph fitted = sNet.getNetwork();
Map<String, INDArray> fittedParams = fitted.paramTable(); Map<String, INDArray> fittedParams = fitted.getParamTable();
for (Map.Entry<String, INDArray> entry : fittedParams.entrySet()) { for (Map.Entry<String, INDArray> entry : fittedParams.entrySet()) {
INDArray orig = pCopy.get(entry.getKey()); INDArray orig = pCopy.get(entry.getKey());

View File

@ -784,13 +784,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
JavaRDD<DataSet> rdd = sc.parallelize(list); JavaRDD<DataSet> rdd = sc.parallelize(list);
assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount()); assertEquals(0, sparkNet.getNetwork().getNetConfiguration().getIterationCount());
sparkNet.fit(rdd); sparkNet.fit(rdd);
assertEquals(minibatchesPerWorkerPerEpoch, assertEquals(minibatchesPerWorkerPerEpoch,
sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.getNetwork().getNetConfiguration().getIterationCount());
sparkNet.fit(rdd); sparkNet.fit(rdd);
assertEquals(2 * minibatchesPerWorkerPerEpoch, assertEquals(2 * minibatchesPerWorkerPerEpoch,
sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.getNetwork().getNetConfiguration().getIterationCount());
sparkNet.getTrainingMaster().deleteTempFiles(sc); sparkNet.getTrainingMaster().deleteTempFiles(sc);
} }
@ -1074,11 +1074,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
for(int i=0; i<3; i++ ){ for(int i=0; i<3; i++ ){
assertEquals(i, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i, sn1.getNetwork().getNetConfiguration().getEpochCount());
assertEquals(i, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); assertEquals(i, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount());
sn1.fit(rdd); sn1.fit(rdd);
sn2.fit(rdd); sn2.fit(rdd);
assertEquals(i+1, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i+1, sn1.getNetwork().getNetConfiguration().getEpochCount());
assertEquals(i+1, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); assertEquals(i+1, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount());
} }
} }

View File

@ -239,7 +239,7 @@ public class SharedTrainingWrapper {
List<TrainingListener> listeners = worker.getListeners(); List<TrainingListener> listeners = worker.getListeners();
if(listeners != null){ if(listeners != null){
model.setListeners(listeners); model.setListeners(listeners.toArray(new TrainingListener[]{}));
StatsStorageRouter r = worker.getRouter(); StatsStorageRouter r = worker.getRouter();
if(r != null){ if(r != null){
for(TrainingListener l : listeners){ for(TrainingListener l : listeners){
@ -425,7 +425,7 @@ public class SharedTrainingWrapper {
.setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
((ComputationGraph) originalModel).setGradientsAccumulator(accumulator); ((ComputationGraph) originalModel).setGradientsAccumulator(accumulator);
} else if (model instanceof MultiLayerNetwork) { } else if (model instanceof MultiLayerNetwork) {
((MultiLayerNetwork) originalModel).getConfiguration() ((MultiLayerNetwork) originalModel).getNetConfiguration()
.setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode());
((MultiLayerNetwork) originalModel).setGradientsAccumulator(accumulator); ((MultiLayerNetwork) originalModel).setGradientsAccumulator(accumulator);
} }

View File

@ -262,7 +262,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResul
/* /*
Here we're going create our worker, which will be passed into corresponding FlatMapFunction Here we're going create our worker, which will be passed into corresponding FlatMapFunction
*/ */
NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getConfiguration(), NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getNetConfiguration(),
network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray());
voidConfiguration.setUnicastControllerPort(voidConfiguration.getPortSupplier().getPort()); voidConfiguration.setUnicastControllerPort(voidConfiguration.getPortSupplier().getPort());

View File

@ -20,6 +20,7 @@ package org.deeplearning4j.plot;
import com.google.common.util.concurrent.AtomicDouble; import com.google.common.util.concurrent.AtomicDouble;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NonNull;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.api.IModel;
@ -29,15 +30,21 @@ import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree; import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
@ -302,32 +309,104 @@ public class BarnesHutTsne implements IModel {
return x; return x;
} }
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return null;
}
@Override @Override
public ConvexOptimizer getOptimizer() { public ConvexOptimizer getOptimizer() {
return null; return null;
} }
/**
* This method fits model with a given DataSet
*
* @param dataSet
*/
@Override
public void fit(DataSet dataSet) {
}
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet
*/
@Override
public void fit(MultiDataSet dataSet) {
}
/**
* This method fits model with a given DataSetIterator
*
* @param iterator
*/
@Override
public void fit(DataSetIterator iterator) {
}
/**
* This method fits model with a given MultiDataSetIterator
*
* @param iterator
*/
@Override
public void fit(MultiDataSetIterator iterator) {
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations) {
return null;
}
/**
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
* @param evaluations
*/
@Override
public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator,
T... evaluations) {
return null;
}
@Override @Override
public INDArray getParam(String param) { public INDArray getParam(String param) {
return null; return null;
} }
@Override @Override
public void addListeners(TrainingListener... listener) { public void addListeners(TrainingListener... listener) {//no op
// no-op
} }
@Override public Map<String, INDArray> getParamTable() {
public Map<String, INDArray> paramTable() {
return null; return null;
} }
@Override public Map<String, INDArray> getParamTable(boolean backprapParamsOnly) {
public Map<String, INDArray> paramTable(boolean backprapParamsOnly) {
return null; return null;
} }
@Override
public void setParamTable(Map<String, INDArray> paramTable) { public void setParamTable(Map<String, INDArray> paramTable) {
} }
@ -490,7 +569,7 @@ public class BarnesHutTsne implements IModel {
* *
* @param listeners * @param listeners
*/ */
@Override
public void setListeners(Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners) { public void setListeners(Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners) {
} }
@ -901,8 +980,15 @@ public class BarnesHutTsne implements IModel {
return null; return null;
} }
/**
* @param netConfiguration
*/
@Override @Override
public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) {
}
public void setLayerConfiguration(LayerConfiguration layerConfiguration) {
} }
@ -1060,4 +1146,14 @@ public class BarnesHutTsne implements IModel {
public void close(){ public void close(){
//No-op //No-op
} }
/**
* Get the TrainingListeners
*
* @return the training listeners
*/
@Override
public Collection<TrainingListener> getListeners() {
return null;
}
} }

View File

@ -30,6 +30,7 @@ import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -426,10 +427,10 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
//Need to append "0_", "1_" etc to param names from layers... //Need to append "0_", "1_" etc to param names from layers...
int layerIdx = 0; int layerIdx = 0;
for (Layer l : ((MultiLayerNetwork) model).getLayers()) { for (Layer l : ((MultiLayerNetwork) model).getLayers()) {
NeuralNetConfiguration conf = l.getNetConfiguration(); LayerConfiguration conf = l.getLayerConfiguration();
List<String> paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); List<String> paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration());
for (String s : paramkeys) { for (String s : paramkeys) {
double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); double lr = conf.getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
if (Double.isNaN(lr)) { if (Double.isNaN(lr)) {
//Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate
lr = 0.0; lr = 0.0;
@ -440,11 +441,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
} }
} else if (model instanceof ComputationGraph) { } else if (model instanceof ComputationGraph) {
for (Layer l : ((ComputationGraph) model).getLayers()) { for (Layer l : ((ComputationGraph) model).getLayers()) {
NeuralNetConfiguration conf = l.getNetConfiguration(); LayerConfiguration conf = l.getLayerConfiguration();
String layerName = conf.getFirstLayer().getLayerName(); String layerName = conf.getLayerName();
List<String> paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); List<String> paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration());
for (String s : paramkeys) { for (String s : paramkeys) {
double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); double lr = conf.getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount());
if (Double.isNaN(lr)) { if (Double.isNaN(lr)) {
//Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate
lr = 0.0; lr = 0.0;
@ -467,7 +468,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
//--- Histograms --- //--- Histograms ---
if (updateConfig.collectHistograms(StatsType.Parameters)) { if (updateConfig.collectHistograms(StatsType.Parameters)) {
Map<String, Histogram> paramHistograms = getHistograms(model.paramTable(backpropParamsOnly), Map<String, Histogram> paramHistograms = getHistograms(model.getParamTable(backpropParamsOnly),
updateConfig.numHistogramBins(StatsType.Parameters)); updateConfig.numHistogramBins(StatsType.Parameters));
report.reportHistograms(StatsType.Parameters, paramHistograms); report.reportHistograms(StatsType.Parameters, paramHistograms);
} }
@ -490,7 +491,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
//--- Summary Stats: Mean, Variance, Mean Magnitudes --- //--- Summary Stats: Mean, Variance, Mean Magnitudes ---
if (updateConfig.collectMean(StatsType.Parameters)) { if (updateConfig.collectMean(StatsType.Parameters)) {
Map<String, Double> meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean); Map<String, Double> meanParams = calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.Mean);
report.reportMean(StatsType.Parameters, meanParams); report.reportMean(StatsType.Parameters, meanParams);
} }
@ -511,7 +512,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
if (updateConfig.collectStdev(StatsType.Parameters)) { if (updateConfig.collectStdev(StatsType.Parameters)) {
Map<String, Double> stdevParams = Map<String, Double> stdevParams =
calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev); calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.Stdev);
report.reportStdev(StatsType.Parameters, stdevParams); report.reportStdev(StatsType.Parameters, stdevParams);
} }
@ -532,7 +533,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
if (updateConfig.collectMeanMagnitudes(StatsType.Parameters)) { if (updateConfig.collectMeanMagnitudes(StatsType.Parameters)) {
Map<String, Double> meanMagParams = Map<String, Double> meanMagParams =
calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude); calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.MeanMagnitude);
report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams); report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
} }
@ -652,7 +653,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
long numParams; long numParams;
if (model instanceof MultiLayerNetwork) { if (model instanceof MultiLayerNetwork) {
MultiLayerNetwork net = ((MultiLayerNetwork) model); MultiLayerNetwork net = ((MultiLayerNetwork) model);
jsonConf = net.getConfiguration().toJson(); jsonConf = net.getNetConfiguration().toJson();
numLayers = net.getnLayers(); numLayers = net.getnLayers();
numParams = net.numParams(); numParams = net.numParams();
} else if (model instanceof ComputationGraph) { } else if (model instanceof ComputationGraph) {
@ -670,7 +671,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
+ (model == null ? null : model.getClass())); + (model == null ? null : model.getClass()));
} }
Map<String, INDArray> paramMap = model.paramTable(backpropParamsOnly); Map<String, INDArray> paramMap = model.getParamTable(backpropParamsOnly);
String[] paramNames = new String[paramMap.size()]; String[] paramNames = new String[paramMap.size()];
int i = 0; int i = 0;
for (String s : paramMap.keySet()) { //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example for (String s : paramMap.keySet()) { //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example

View File

@ -1129,8 +1129,8 @@ public class TrainModule implements UIModule {
NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(configJson); NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(configJson);
int confIdx = layerIdx - 1; //-1 because of input int confIdx = layerIdx - 1; //-1 because of input
if (confIdx >= 0) { if (confIdx >= 0) {
nnc = conf.getNetConfigurations().get(confIdx); layer = conf.getFlattenedLayerConfigurations().get(confIdx);
layer = nnc.getFirstLayer(); nnc = layer.getNetConfiguration();
} else { } else {
//Input layer //Input layer
layerType = "Input"; layerType = "Input";
@ -1144,7 +1144,7 @@ public class TrainModule implements UIModule {
if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) { if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) vertices.get(vertexName); LayerVertex lv = (LayerVertex) vertices.get(vertexName);
nnc = lv.getNetConfiguration(); nnc = lv.getNetConfiguration();
layer = nnc.getFirstLayer(); layer = lv.getLayerConfiguration();
} else if (conf.getNetworkInputs().contains(vertexName)) { } else if (conf.getNetworkInputs().contains(vertexName)) {
layerType = "Input"; layerType = "Input";
} else { } else {
@ -1177,7 +1177,7 @@ public class TrainModule implements UIModule {
if (layer instanceof BaseLayer) { if (layer instanceof BaseLayer) {
BaseLayer bl = (BaseLayer) layer; BaseLayer bl = (BaseLayer) layer;
activationFn = bl.getActivationFn().toString(); activationFn = bl.getActivationFn().toString();
long nParams = layer.initializer().numParams(nnc.getFirstLayer()); long nParams = layer.initializer().numParams(bl.getLayer());
layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"), layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"),
String.valueOf(nParams)}); String.valueOf(nParams)});
if (nParams > 0) { if (nParams > 0) {

View File

@ -62,24 +62,24 @@ public class TrainModuleUtils {
layerInfo.add(Collections.emptyMap()); layerInfo.add(Collections.emptyMap());
List<NeuralNetConfiguration> list = config.getNetConfigurations(); List<LayerConfiguration> list = config.getFlattenedLayerConfigurations();
int layerIdx = 1; int layerIdx = 1;
for (NeuralNetConfiguration c : list) { for (LayerConfiguration c : list) {
LayerConfiguration layer = c.getFirstLayer(); LayerConfiguration layer = c;
String layerName = layer.getLayerName(); String layerName = layer.getLayerName();
if (layerName == null) if (layerName == null)
layerName = "layer" + layerIdx; layerName = "layer" + layerIdx;
vertexNames.add(layerName); vertexNames.add(layerName);
originalVertexName.add(String.valueOf(layerIdx - 1)); originalVertexName.add(String.valueOf(layerIdx - 1));
String layerType = c.getFirstLayer().getClass().getSimpleName().replaceAll("Layer$", ""); String layerType = c.getClass().getSimpleName().replaceAll("Layer$", "");
layerTypes.add(layerType); layerTypes.add(layerType);
layerInputs.add(Collections.singletonList(layerIdx - 1)); layerInputs.add(Collections.singletonList(layerIdx - 1));
layerIdx++; layerIdx++;
//Extract layer info //Extract layer info
Map<String, String> map = getLayerInfo(c, layer); Map<String, String> map = getLayerInfo(c.getNetConfiguration(), layer);
layerInfo.add(map); layerInfo.add(map);
} }
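A condensed sketch of the new iteration pattern, assuming only getFlattenedLayerConfigurations() and the LayerConfiguration getters used above:
    int layerIdx = 1;
    for (LayerConfiguration lc : config.getFlattenedLayerConfigurations()) {
        String name = lc.getLayerName() != null ? lc.getLayerName() : "layer" + layerIdx;
        String type = lc.getClass().getSimpleName().replaceAll("Layer$", "");
        layerIdx++;
        // name and type feed the vertexNames / layerTypes lists built above
    }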

View File

@ -143,7 +143,7 @@
<label class="btn btn-secondary" onclick="setLayout('dagre')"> <label class="btn btn-secondary" onclick="setLayout('dagre')">
<input type="radio" name="options" id="option3" autocomplete="off">Alt</label> <input type="radio" name="options" id="option3" autocomplete="off">Alt</label>
<label class="btn btn-secondary" onclick="setLayout('cose-bilkent')"> <label class="btn btn-secondary" onclick="setLayout('cose-bilkent')">
<input type="radio" name="options" id="option3" autocomplete="off">Spread</label> <input type="radio" name="options" id="option4" autocomplete="off">Spread</label>
</div> </div>
<br> <br>
<br> <br>

View File

@ -45,7 +45,7 @@ public class TestUtils {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration());
assertEquals(net.params(), restored.params()); assertEquals(net.params(), restored.params());
return restored; return restored;