parent fec570ff98
commit 9af4f9f23a

@@ -21,8 +21,19 @@
 package net.brutex.gan;
 
-import java.util.List;
+import java.awt.BorderLayout;
+import java.awt.Dimension;
+import java.awt.GridLayout;
+import java.awt.Image;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.util.Arrays;
 import java.util.Random;
+import javax.swing.ImageIcon;
+import javax.swing.JFrame;
+import javax.swing.JLabel;
+import javax.swing.JPanel;
+import javax.swing.WindowConstants;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.ArrayUtils;
 import org.datavec.api.split.FileSplit;

@@ -34,20 +45,23 @@ import org.datavec.image.transform.PipelineImageTransform;
 import org.datavec.image.transform.ResizeImageTransform;
 import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.ActivationLayer;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.DropoutLayer;
+import org.deeplearning4j.nn.conf.layers.Layer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
-import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;

@@ -55,13 +69,6 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.IUpdater;
-
-
-import javax.swing.*;
-import java.awt.*;
-import java.awt.image.BufferedImage;
-import java.io.File;
-import java.util.Arrays;
 import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
 
 @Slf4j

@@ -106,7 +113,7 @@ public class App {
    * @return config
    */
   private static MultiLayerConfiguration generator() {
-    MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder()
+    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
         .seed(42)
         .updater(UPDATER)
         .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)

@@ -117,23 +124,8 @@ public class App {
         .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
         // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
         .build();
-    log.debug("Generator network: \n{}", confxx.toJson());
-
-    NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build();
-
-    NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder()
-        .cacheMode(CacheMode.HOST)
-        .layer( new DenseLayer.Builder().build())
-        .layer( new DenseLayer.Builder().build())
-        .layer( BuildingBlockLayer.builder().build())
-        .layers( List.of(genLayers()))
-        .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
-        .build();
-
-    return confx;
+    return conf;
 
   }
 
   private static Layer[] disLayers() {
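
Aside: a minimal sketch (not from this commit) of how the simplified generator() configuration above is typically consumed in DL4J. buildGenerator is a hypothetical helper name; listener and training wiring are omitted.

    // Hypothetical helper inside App, for illustration only (not in this diff).
    private static MultiLayerNetwork buildGenerator() {
        MultiLayerConfiguration conf = generator();      // method changed in the hunk above
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();                                      // allocate parameters before use
        return net;
    }
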
@@ -155,6 +147,7 @@ public class App {
   }
 
   private static MultiLayerConfiguration discriminator() {
+
     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
         .seed(42)
         .updater(UPDATER)

@@ -183,11 +176,11 @@ public class App {
 
     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
         .seed(42)
-        .updater(UPDATER)
+        .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() )
         .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer)
-        .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
-        .weightInit(WeightInit.XAVIER)
-        .activation(Activation.IDENTITY)
+        .gradientNormalizationThreshold( 100 )
+        .weightInit( new WeightInitXavier() )
+        .activation( new ActivationIdentity())
         .list( layers )
         .setInputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
         .build();
 
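
Aside: the hunk above swaps enum-style hyperparameters for object-style ones. A minimal sketch of the equivalent standalone values, assuming only the imports added earlier in this diff; variable names are illustrative.

    IUpdater updater = Adam.builder()
            .learningRate(0.0002)   // learning rate set in the new GAN config
            .beta1(0.5)             // lower beta1, as in the new code
            .build();
    WeightInitXavier init = new WeightInitXavier();     // replaces WeightInit.XAVIER
    ActivationIdentity act = new ActivationIdentity();  // replaces Activation.IDENTITY
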
@@ -295,7 +295,7 @@ public class BrianTest extends BaseSparkSessionTest {
             .activation(Activation.RELU).l2(0.001).build())
         .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER)
             .activation(Activation.RELU).build())
-        //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
+        //.layer(2, new DenseLayerConfiguration.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
         .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4)
             .weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
         .build();

@@ -301,7 +301,7 @@ public class BrianTest2 /*extends BaseDL4JTest*/ {
         .list()
         .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build())
         .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
-        //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
+        //.layer(2, new DenseLayerConfiguration.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
         .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
         .build();
 

@@ -95,7 +95,7 @@ public class TestServer {
         .list()
         //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build())
         //.layer(1, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build())
-        // .layer(1, new DenseLayer.Builder().nIn(10).nOut(64).activation(Activation.RELU).build())
+        // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build())
         .layer(0, new DenseLayer.Builder().nIn(10).nOut(100).activation(Activation.RELU).l2(0.003).build())
         .layer(1, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())
         .layer(2, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())

@@ -131,7 +131,7 @@ public class TestServer2 {
         .list()
         //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build())
         //.layer(1, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build())
-        // .layer(1, new DenseLayer.Builder().nIn(10).nOut(64).activation(Activation.RELU).build())
+        // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build())
        .layer(0, new DenseLayer.Builder().nIn(10).nOut(100).activation(Activation.RELU).l2(0.003).build())
        .layer(1, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())
        .layer(2, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())

@@ -284,7 +284,7 @@ public class IntegrationTestBaselineGenerator {
         INDArray paramsPostTraining;
         if (modelType == ModelType.MLN) {
             int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
-            Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
+            Preconditions.checkState(layersToTrain != null, "ILayer indices must not be null");
             DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
 
             for (int i : layersToTrain) {

@@ -293,7 +293,7 @@
             paramsPostTraining = mln.params();
         } else if (modelType == ModelType.CG) {
             String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
-            Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
+            Preconditions.checkState(layersToTrain != null, "ILayer names must not be null");
 
             for (String i : layersToTrain) {
                 cg.pretrainLayer(i, iter);

@@ -200,7 +200,7 @@ public class IntegrationTestRunner {
             m = cg;
 
             ComputationGraph loaded = ComputationGraph.load(savedModel, true);
-            assertEquals(loaded.getConfiguration(), cg.getConfiguration(), "Configs not equal" );
+            assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" );
             assertEquals( loaded.params(), cg.params(), "Params not equal");
             assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal");
         } else if(config instanceof SameDiff){

@@ -383,7 +383,7 @@
         org.deeplearning4j.nn.api.Layer[] layers;
         if(modelType == ModelType.MLN){
             int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN();
-            Preconditions.checkState(layersToTrain != null, "Layer indices must not be null");
+            Preconditions.checkState(layersToTrain != null, "ILayer indices must not be null");
             DataSetIterator dsi = new MultiDataSetWrapperIterator(iter);
 
             for( int i : layersToTrain){

@@ -393,7 +393,7 @@
             layers = mln.getLayers();
         } else if(modelType == ModelType.CG) {
             String[] layersToTrain = tc.getUnsupervisedTrainLayersCG();
-            Preconditions.checkState(layersToTrain != null, "Layer names must not be null");
+            Preconditions.checkState(layersToTrain != null, "ILayer names must not be null");
 
             for( String i : layersToTrain){
                 cg.pretrainLayer(i, iter);

@@ -429,8 +429,8 @@
             isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT;
             tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength();
         } else if(modelType == ModelType.CG) {
-            isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
-            tbpttLength = cg.getConfiguration().getTbpttFwdLength();
+            isTbptt = cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT;
+            tbpttLength = cg.getComputationGraphConfiguration().getTbpttFwdLength();
         } else {
             isTbptt = false;
             tbpttLength = 0;
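
Aside: a minimal sketch of the accessor rename these hunks apply, where getConfiguration() becomes getComputationGraphConfiguration(). Here cg is a ComputationGraph as in the surrounding test code, and every call shown appears in the diff itself.

    // Old: cg.getConfiguration()  ->  New: cg.getComputationGraphConfiguration()
    ComputationGraphConfiguration cgConf = cg.getComputationGraphConfiguration();
    boolean isTbptt = cgConf.getBackpropType() == BackpropType.TruncatedBPTT;
    int tbpttLength = isTbptt ? cgConf.getTbpttFwdLength() : 0;
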
@@ -458,11 +458,11 @@
             epochAfter = mln.getEpochCount();
             layers = mln.getLayers();
         } else if(modelType == ModelType.CG){
-            iterBefore = cg.getConfiguration().getIterationCount();
-            epochBefore = cg.getConfiguration().getEpochCount();
+            iterBefore = cg.getComputationGraphConfiguration().getIterationCount();
+            epochBefore = cg.getComputationGraphConfiguration().getEpochCount();
             cg.fit(countingIter);
-            iterAfter = cg.getConfiguration().getIterationCount();
-            epochAfter = cg.getConfiguration().getEpochCount();
+            iterAfter = cg.getComputationGraphConfiguration().getIterationCount();
+            epochAfter = cg.getComputationGraphConfiguration().getEpochCount();
             layers = cg.getLayers();
         } else {
             iterBefore = sd.getTrainingConfig().getIterationCount();

@@ -611,7 +611,7 @@
         } else if(modelType == ModelType.CG){
             ModelSerializer.writeModel(m, f, true);
             ComputationGraph restored = ComputationGraph.load(f, true);
-            assertEquals(cg.getConfiguration(), restored.getConfiguration());
+            assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
             assertEquals(cg.params(), restored.params());
         } else {
             sd.save(f, true);

@@ -745,7 +745,7 @@
             preProcessors = mln.getLayerWiseConfigurations().getInputPreProcessors().values();
         } else {
             preProcessors = new ArrayList<>();
-            for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) {
+            for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) {
                 if (gv instanceof LayerVertex) {
                     InputPreProcessor pp = ((LayerVertex) gv).getPreProcessor();
                     if (pp != null) {

@@ -760,7 +760,7 @@
 
         //Collect vertex coverage information
         if (!isMLN) {
-            for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) {
+            for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) {
                 vertexConfClassesSeen.put(gv.getClass(), vertexConfClassesSeen.getOrDefault(gv.getClass(), 0) + 1);
             }
         }

@@ -872,14 +872,14 @@
 
         log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
 
-        log.info("Layer coverage - classes seen:");
+        log.info("ILayer coverage - classes seen:");
         for (Class<?> c : layerClasses) {
             if (layerConfClassesSeen.containsKey(c)) {
                 log.info("Class seen {} times in tests: {}", layerConfClassesSeen.get(c), c.getName());
             }
         }
 
-        log.info("Layer classes NOT seen in any tests:");
+        log.info("ILayer classes NOT seen in any tests:");
         for (Class<?> c : layerClasses) {
             if (!layerConfClassesSeen.containsKey(c)) {
                 log.info("Class NOT seen in any tests: {}", c.getName());

@@ -73,7 +73,7 @@ public class TestUtils {
             ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             restored = ModelSerializer.restoreComputationGraph(bais, true);
 
-            assertEquals(net.getConfiguration(), restored.getConfiguration());
+            assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
             assertEquals(net.params(), restored.params());
         } catch (IOException e){
             //Should never happen

@@ -81,7 +81,7 @@
         }
 
         //Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
-        ComputationGraphConfiguration conf = net.getConfiguration();
+        ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
         serializeDeserializeJava(conf);
 
         return restored;

@@ -90,7 +90,7 @@
             ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
             restored = ModelSerializer.restoreComputationGraph(bais, true);
 
-            assertEquals(net.getConfiguration(), restored.getConfiguration());
+            assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
             assertEquals(net.params(), restored.params());
         } catch (IOException e){
             //Should never happen

@@ -98,7 +98,7 @@
         }
 
         //Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
-        ComputationGraphConfiguration conf = net.getConfiguration();
+        ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
         serializeDeserializeJava(conf);
 
         return restored;
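
Aside: a minimal sketch of the save/restore round trip that the TestUtils hunks above assert on, using the renamed getter. It assumes net is a ComputationGraph and that the stream-based ModelSerializer overloads are available; names are illustrative.

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ModelSerializer.writeModel(net, baos, true);   // save model together with updater state
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(
            new ByteArrayInputStream(baos.toByteArray()), true);
    assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration());
    assertEquals(net.params(), restored.params());
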
@ -626,7 +626,7 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
net.evaluate(iter);
|
net.evaluate(iter);
|
||||||
net.evaluateROCMultiClass(iter, 0);
|
net.evaluateROCMultiClass(iter, 0);
|
||||||
|
|
||||||
cg.getConfiguration().setValidateOutputLayerConfig(false);
|
cg.getComputationGraphConfiguration().setValidateOutputLayerConfig(false);
|
||||||
cg.evaluate(iter);
|
cg.evaluate(iter);
|
||||||
cg.evaluateROCMultiClass(iter, 0);
|
cg.evaluateROCMultiClass(iter, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,7 +90,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -135,7 +135,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -237,7 +237,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
||||||
// for (int k = 0; k < mln.getnLayers(); k++)
|
// for (int k = 0; k < mln.getnLayers(); k++)
|
||||||
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
|
// System.out.println("ILayer " + k + " # params: " + mln.getLayer(k).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -341,7 +341,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
||||||
// for (int k = 0; k < mln.getnLayers(); k++)
|
// for (int k = 0; k < mln.getnLayers(); k++)
|
||||||
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
|
// System.out.println("ILayer " + k + " # params: " + mln.getLayer(k).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -385,7 +385,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -430,7 +430,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
@ -572,7 +572,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
|
||||||
// for (int k = 0; k < net.getNumLayers(); k++)
|
// for (int k = 0; k < net.getNumLayers(); k++)
|
||||||
// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());
|
// System.out.println("ILayer " + k + " # params: " + net.getLayer(k).numParams());
|
||||||
|
|
||||||
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
|
||||||
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
|
||||||
|
|
|
@ -118,7 +118,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -198,7 +198,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -282,7 +282,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -359,7 +359,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
|
|
@ -149,7 +149,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
// for (int j = 0; j < net.getnLayers(); j++) {
|
||||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
// for (int j = 0; j < net.getnLayers(); j++) {
|
||||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,7 +431,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
// for (int j = 0; j < net.getnLayers(); j++) {
|
||||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -530,7 +530,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
// for (int j = 0; j < net.getnLayers(); j++) {
|
||||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
|
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
|
||||||
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
|
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -231,7 +231,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst);
|
+ doLearningFirst);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -293,7 +293,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
@ -361,7 +361,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
@ -427,7 +427,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -500,7 +500,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -920,7 +920,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
|
|
@ -95,7 +95,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
|
System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
|
||||||
+ miniBatchSize);
|
+ miniBatchSize);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -156,7 +156,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
|
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -216,7 +216,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
|
System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
||||||
|
@ -299,7 +299,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
|
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
|
||||||
+ miniBatchSize);
|
+ miniBatchSize);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
||||||
|
|
|
@ -123,7 +123,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst);
|
+ doLearningFirst);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -203,7 +203,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst);
|
+ doLearningFirst);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -297,7 +297,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
+ ", lossFn=" + lf + ", outputActivation=" + outputActivation
|
+ ", lossFn=" + lf + ", outputActivation=" + outputActivation
|
||||||
+ ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1);
|
+ ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -342,7 +342,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testEmbeddingLayerSimple");
|
System.out.println("testEmbeddingLayerSimple");
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -382,7 +382,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testEmbeddingLayerSimple");
|
System.out.println("testEmbeddingLayerSimple");
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -472,7 +472,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -714,7 +714,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
// (a) activation function
|
// (a) activation function
|
||||||
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
||||||
// (c) Loss function (with specified output activations)
|
// (c) Loss function (with specified output activations)
|
||||||
// (d) Layer Normalization enabled / disabled
|
// (d) ILayer Normalization enabled / disabled
|
||||||
Activation[] activFns = {Activation.SIGMOID, Activation.TANH};
|
Activation[] activFns = {Activation.SIGMOID, Activation.TANH};
|
||||||
boolean[] characteristic = {true, false}; //If true: run some backprop steps first
|
boolean[] characteristic = {true, false}; //If true: run some backprop steps first
|
||||||
|
|
||||||
|
@ -776,7 +776,7 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst + ", layerNorm=" + layerNorm);
|
+ doLearningFirst + ", layerNorm=" + layerNorm);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
|
|
@ -106,7 +106,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testBasicIris()");
|
System.out.println("testBasicIris()");
|
||||||
// for (int j = 0; j < graph.getNumLayers(); j++)
|
// for (int j = 0; j < graph.getNumLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
|
// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
||||||
|
@@ -157,7 +157,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testBasicIrisWithMerging()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -214,7 +214,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -274,7 +274,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -376,7 +376,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -439,7 +439,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -478,7 +478,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testLSTMWithSubset()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -515,7 +515,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testLSTMWithLastTimeStepVertex()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 //First: test with no input mask array
@@ -579,7 +579,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testLSTMWithDuplicateToTimeSeries()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input1, input2})
@@ -628,7 +628,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testLSTMWithReverseTimeSeriesVertex()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -683,7 +683,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(inputs)
@@ -723,7 +723,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -769,7 +769,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(input)
@@ -820,7 +820,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
@@ -888,7 +888,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testBasicIrisTripletStackingL2Loss()");
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{pos, anc, neg})
@@ -949,7 +949,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{example})
@@ -1014,7 +1014,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < net.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -1063,7 +1063,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
@@ -1121,7 +1121,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
@@ -1179,7 +1179,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
@@ -1242,7 +1242,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 graph.setLayerMaskArrays(new INDArray[] {inMask1, inMask2}, null);
@@ -1301,7 +1301,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
@@ -1347,7 +1347,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1})
@@ -1398,7 +1398,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < graph.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1})
@@ -1436,7 +1436,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println("testGraphEmbeddingLayerSimple");
 // for (int j = 0; j < cg.getNumLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + cg.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input})
@@ -84,7 +84,7 @@ public class LRNGradientCheckTests extends BaseDL4JTest {

 // if (PRINT_RESULTS) {
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 // }

 boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -126,7 +126,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -215,7 +215,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
@@ -343,7 +343,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
 + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2
 + ", l1=" + l1);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -78,7 +78,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {

 .dist(new NormalDistribution(0, 1))
 .activation(Activation.TANH)
-.hasBias(true) //Layer 0: Always have a bias
+.hasBias(true) //ILayer 0: Always have a bias
 .build())
 .layer(1, new DenseLayer.Builder().nIn(layerSize).nOut(layerSize)

@@ -137,7 +137,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 System.out.println("Starting test: " + testName);
@@ -244,7 +244,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 System.out.println("Starting test: " + testName);
@@ -393,7 +393,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(testName);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 System.out.println("Starting test: " + testName);
@@ -124,7 +124,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -195,7 +195,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int l = 0; l < mln.getnLayers(); l++)
-// System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams());
+// System.out.println("ILayer " + l + " # params: " + mln.getLayer(l).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,
@@ -283,7 +283,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,
@@ -325,7 +325,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
 if (PRINT_RESULTS) {
 System.out.println(msg);
 // for (int j = 0; j < mln.getnLayers(); j++)
-// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,
@@ -133,8 +133,8 @@ public class LayerConfigTest extends BaseDL4JTest {

 //Learning rate without layerwise override:
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list()
-.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -143,8 +143,8 @@ public class LayerConfigTest extends BaseDL4JTest {

 //With:
 conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list()
-.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).learningRate(0.2).build()).build();
+.layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).learningRate(0.2).build()).build();

 net = new MultiLayerNetwork(conf);
 net.init();
@@ -154,8 +154,8 @@ public class LayerConfigTest extends BaseDL4JTest {

 //L1 and L2 without layerwise override:
 conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list()
-.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 net = new MultiLayerNetwork(conf);
 net.init();

@@ -166,8 +166,8 @@ public class LayerConfigTest extends BaseDL4JTest {

 //L1 and L2 with layerwise override:
 conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list()
-.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l1(0.9).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.8).build()).build();
+.layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l1(0.9).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l2(0.8).build()).build();
 net = new MultiLayerNetwork(conf);
 net.init();

@@ -326,8 +326,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr)
 .updater(Updater.SGD)
 .learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list()
-.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -345,8 +345,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 int iterations = 1;
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
 .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate)
-.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -367,8 +367,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 int iterations = 1;
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
 .learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate)
-.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.lrPolicySteps(steps).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -388,8 +388,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 int iterations = 1;
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
 .learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate)
-.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -409,8 +409,8 @@ public class LayerConfigTest extends BaseDL4JTest {
 int iterations = 1;
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
 .learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate)
-.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
-.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
+.lrPolicySteps(steps).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build())
+.layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();

@@ -229,7 +229,7 @@ public class DTypeTests extends BaseDL4JTest {
 if (seenLayers.size() < layerClasses.size()) {
 for (Class<?> c : layerClasses) {
 if (!seenLayers.contains(c) && !ignoreClasses.contains(c)) {
-log.warn("Layer class not tested for global vs. network datatypes: {}", c);
+log.warn("ILayer class not tested for global vs. network datatypes: {}", c);
 fail = true;
 }
 }
@@ -279,7 +279,7 @@ public class DTypeTests extends BaseDL4JTest {
 }

 public static void logUsedClasses(ComputationGraph net) {
-ComputationGraphConfiguration conf = net.getConfiguration();
+ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
 for (GraphVertex gv : conf.getVertices().values()) {
 seenVertices.add(gv.getClass());
 if (gv instanceof LayerVertex) {
@@ -65,7 +65,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
 Nd4j.getRandom().setSeed(12345);
 int timeSeriesLength = 12;

-//4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
+//4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors.
 ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
 .addInputs("in")
 .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
@@ -208,7 +208,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
 Nd4j.getRandom().setSeed(12345);
 int timeSeriesLength = 12;

-//4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
+//4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors.
 //Network architecture: lstm0 -> Dense -> RnnOutputLayer0
 // and lstm1 -> Dense -> RnnOutputLayer1
 ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
@@ -391,9 +391,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
 graphTBPTT.init();
 graphTBPTT.clearTbpttState = false;

-assertEquals(BackpropType.TruncatedBPTT, graphTBPTT.getConfiguration().getBackpropType());
-assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttFwdLength());
-assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttBackLength());
+assertEquals(BackpropType.TruncatedBPTT, graphTBPTT.getComputationGraphConfiguration().getBackpropType());
+assertEquals(timeSeriesLength, graphTBPTT.getComputationGraphConfiguration().getTbpttFwdLength());
+assertEquals(timeSeriesLength, graphTBPTT.getComputationGraphConfiguration().getTbpttBackLength());

 INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
 INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);
@@ -42,7 +42,6 @@ import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.conditions.Conditions;
 import org.nd4j.linalg.learning.config.Adam;

-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;

@@ -168,8 +167,8 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest {
 net.init();

 ComputationGraph cg = net.toComputationGraph();
-cg.getConfiguration().setInferenceWorkspaceMode(wsm);
-cg.getConfiguration().setTrainingWorkspaceMode(wsm);
+cg.getComputationGraphConfiguration().setInferenceWorkspaceMode(wsm);
+cg.getComputationGraphConfiguration().setTrainingWorkspaceMode(wsm);
 DataSetIterator ds = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(1, true, 12345), 1);
 Nd4j.getRandom().setSeed(12345);
 net.pretrainLayer(0, ds);
@@ -1033,15 +1033,15 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

 DataSetIterator iter = new IrisDataSetIterator(50, 150);

-assertEquals(0, network.getConfiguration().getIterationCount());
+assertEquals(0, network.getComputationGraphConfiguration().getIterationCount());
 network.fit(iter);
-assertEquals(3, network.getConfiguration().getIterationCount());
+assertEquals(3, network.getComputationGraphConfiguration().getIterationCount());
 iter.reset();
 network.fit(iter);
-assertEquals(6, network.getConfiguration().getIterationCount());
+assertEquals(6, network.getComputationGraphConfiguration().getIterationCount());
 iter.reset();
 network.fit(iter.next());
-assertEquals(7, network.getConfiguration().getIterationCount());
+assertEquals(7, network.getComputationGraphConfiguration().getIterationCount());

 ByteArrayOutputStream baos = new ByteArrayOutputStream();
 ModelSerializer.writeModel(network, baos, true);
@@ -1049,7 +1049,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {

 ByteArrayInputStream bais = new ByteArrayInputStream(asBytes);
 ComputationGraph net = ModelSerializer.restoreComputationGraph(bais, true);
-assertEquals(7, net.getConfiguration().getIterationCount());
+assertEquals(7, net.getComputationGraphConfiguration().getIterationCount());
 }

 @Test
@@ -1272,18 +1272,18 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 ComputationGraph net = new ComputationGraph(conf);
 net.init();

-assertEquals(0, net.getConfiguration().getEpochCount());
+assertEquals(0, net.getComputationGraphConfiguration().getEpochCount());


 DataSetIterator iter = new IrisDataSetIterator(150, 150);

 for( int i=0; i<4; i++ ){
-assertEquals(i, net.getConfiguration().getEpochCount());
+assertEquals(i, net.getComputationGraphConfiguration().getEpochCount());
 net.fit(iter);
-assertEquals(i+1, net.getConfiguration().getEpochCount());
+assertEquals(i+1, net.getComputationGraphConfiguration().getEpochCount());
 }

-assertEquals(4, net.getConfiguration().getEpochCount());
+assertEquals(4, net.getComputationGraphConfiguration().getEpochCount());

 ByteArrayOutputStream baos = new ByteArrayOutputStream();

@@ -1293,7 +1293,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 ByteArrayInputStream bais = new ByteArrayInputStream(bytes);

 ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true);
-assertEquals(4, restored.getConfiguration().getEpochCount());
+assertEquals(4, restored.getComputationGraphConfiguration().getEpochCount());
 }

 @Test
@@ -1619,13 +1619,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 GraphIndices indices = cg.calculateIndices();

 int[] order = cg.topologicalSortOrder();
-List<String> strOrder = cg.getConfiguration().getTopologicalOrderStr();
+List<String> strOrder = cg.getComputationGraphConfiguration().getTopologicalOrderStr();
 INDArray[] out1 = cg.output(in);

 //Check it's the same after loading:
 ComputationGraph cg2 = TestUtils.testModelSerialization(cg);
 int[] order2 = cg2.topologicalSortOrder();
-List<String> strOrder2 = cg.getConfiguration().getTopologicalOrderStr();
+List<String> strOrder2 = cg.getComputationGraphConfiguration().getTopologicalOrderStr();
 assertArrayEquals(order, order2);
 assertEquals(strOrder, strOrder2);

@@ -1633,7 +1633,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 assertArrayEquals(out1, out2);

 //Delete the topological order, ensure it gets recreated properly:
-ComputationGraphConfiguration conf3 = cg2.getConfiguration().clone();
+ComputationGraphConfiguration conf3 = cg2.getComputationGraphConfiguration().clone();
 conf3.setTopologicalOrder(null);
 conf3.setTopologicalOrderStr(null);
 ComputationGraph cg3 = new ComputationGraph(conf3);
@@ -1641,7 +1641,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 cg3.setParams(cg2.params());

 int[] order3 = cg3.topologicalSortOrder();
-List<String> strOrder3 = cg.getConfiguration().getTopologicalOrderStr();
+List<String> strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr();
 INDArray[] out3 = cg3.output(in);
 assertArrayEquals(order, order3);
 assertEquals(strOrder, strOrder3);
@@ -235,7 +235,7 @@ public class FrozenLayerTest extends BaseDL4JTest {
 ComputationGraph clonedModel = modelNow.clone();

 //Check json
-assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
+assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());

 //Check params
 assertEquals(modelNow.params(), clonedModel.params());
@@ -50,7 +50,7 @@ public class TestDropout extends BaseDL4JTest {
 @Test
 public void testDropoutSimple() throws Exception {
 //Testing dropout with a single layer
-//Layer input: values should be set to either 0.0 or 2.0x original value
+//ILayer input: values should be set to either 0.0 or 2.0x original value

 int nIn = 8;
 int nOut = 8;
@@ -200,7 +200,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 @Test
 public void testEmbeddingForwardPass() {
 //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation
-// input with a DenseLayer
+// input with a DenseLayerConfiguration

 int nClassesIn = 10;

@@ -243,7 +243,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
 @Test
 public void testEmbeddingBackwardPass() {
 //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation
-// input with a DenseLayer
+// input with a DenseLayerConfiguration

 int nClassesIn = 10;

@@ -104,7 +104,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest {
 + "ocnn" + "sigmoid" + ", doLearningFirst="
 + doLearningFirst);
 for (int j = 0; j < network.getnLayers(); j++)
-System.out.println("Layer " + j + " # params: " + network.getLayer(j).numParams());
+System.out.println("ILayer " + j + " # params: " + network.getLayer(j).numParams());
 }

 boolean gradOK = GradientCheckUtil.checkGradients(network, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -98,7 +98,7 @@ public class SameDiffDense extends SameDiffLayer {
 if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){
 e.getValue().assign(0.0);
 } else {
-//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer
+//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayerConfiguration
 WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue());
 }
 }
@@ -72,14 +72,14 @@ public class SameDiffDenseVertex extends SameDiffVertex {

 @Override
 public void initializeParameters(Map<String, INDArray> params) {
-//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer
+//Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayerConfiguration
 WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', params.get("W"));
 params.get("b").assign(0.0);
 }

 @Override
 public char paramReshapeOrder(String paramName){
-return 'f'; //To match DL4J DenseLayer - for easy comparison
+return 'f'; //To match DL4J DenseLayerConfiguration - for easy comparison
 }

 @Override
@@ -73,8 +73,8 @@ public class WorkspaceTests extends BaseDL4JTest {
 ComputationGraph c = createNet();
 for (WorkspaceMode wm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
 log.info("Starting test: {}", wm);
-c.getConfiguration().setTrainingWorkspaceMode(wm);
-c.getConfiguration().setInferenceWorkspaceMode(wm);
+c.getComputationGraphConfiguration().setTrainingWorkspaceMode(wm);
+c.getComputationGraphConfiguration().setInferenceWorkspaceMode(wm);

 INDArray f = Nd4j.rand(8, 1, 28, 28);
 INDArray l = Nd4j.rand(8, 10);
@@ -666,8 +666,8 @@ public class WorkspaceTests extends BaseDL4JTest {
 ComputationGraph c = createNet();
 for (WorkspaceMode wm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
 log.info("Starting test: {}", wm);
-c.getConfiguration().setTrainingWorkspaceMode(wm);
-c.getConfiguration().setInferenceWorkspaceMode(wm);
+c.getComputationGraphConfiguration().setTrainingWorkspaceMode(wm);
+c.getComputationGraphConfiguration().setInferenceWorkspaceMode(wm);

 INDArray f = Nd4j.rand(8, 1, 28, 28);
 INDArray l = Nd4j.rand(8, 10);
@@ -995,7 +995,7 @@ public class MultiLayerTest extends BaseDL4JTest {

 @Test
 public void testCompareLayerMethods(){
-//Simple test: compare .layer(int, Layer) and .layer(Layer) are identical
+//Simple test: compare .layer(int, ILayer) and .layer(ILayer) are identical

 MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list()
 .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
@@ -261,7 +261,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 Nd4j.getRandom().setSeed(12345);
 int timeSeriesLength = 12;

-//4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
+//4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors.
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
 .layer(0, l0)
 .layer(1, l1)
@@ -216,8 +216,8 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest {
 net2GradUpd.getUpdater().getStateViewArray());

 //Remove the next 2 lines: fails - as net 1 is 1 iteration ahead
-net1GradCalc.getConfiguration().setIterationCount(0);
-net2GradUpd.getConfiguration().setIterationCount(0);
+net1GradCalc.getComputationGraphConfiguration().setIterationCount(0);
+net2GradUpd.getComputationGraphConfiguration().setIterationCount(0);


 for (int i = 0; i < 100; i++) {
@@ -120,7 +120,7 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest {
 assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer);
 assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer);

-Map<String, GraphVertex> m = withFrozen.getConfiguration().getVertices();
+Map<String, GraphVertex> m = withFrozen.getComputationGraphConfiguration().getVertices();
 Layer l0 = ((LayerVertex) m.get("0")).getLayerConf().getLayer();
 Layer l1 = ((LayerVertex) m.get("1")).getLayerConf().getLayer();
 assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer);
@@ -102,7 +102,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 .build();

 //Check json
-assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson());
+assertEquals(expectedConf.toJson(), modelNow.getComputationGraphConfiguration().toJson());

 //Check params after fit
 modelNow.fit(randomData);
@@ -382,7 +382,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
 modelExpectedArch.getVertex("layer0").setLayerAsFrozen();
 modelExpectedArch.getVertex("layer1").setLayerAsFrozen();

-assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
+assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson());

 modelNow.setParams(modelExpectedArch.params());
 int i = 0;
@@ -445,7 +445,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {


 // assertEquals(confExpected, graph.getConfiguration());
-assertEquals(confExpected.toJson(), graph.getConfiguration().toJson());
+assertEquals(confExpected.toJson(), graph.getComputationGraphConfiguration().toJson());
 }


@@ -126,7 +126,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
 .setOutputs("outLeft", "outCentre", "outRight").build();
 ComputationGraph expectedModel = new ComputationGraph(expectedConf);
 expectedModel.init();
-assertEquals(expectedConf.toJson(), modelSubset.getConfiguration().toJson());
+assertEquals(expectedConf.toJson(), modelSubset.getComputationGraphConfiguration().toJson());
 }

 @Test
@@ -764,7 +764,7 @@ public class TestOptimizers extends BaseDL4JTest {
 }


-/** Simple abstract class to deal with the fact that we don't care about the majority of the Model/Layer
+/** Simple abstract class to deal with the fact that we don't care about the majority of the Model/ILayer
 * methods here. Classes extending this model for optimizer tests need only implement the score() and
 * gradient() methods.
 */
@@ -907,7 +907,7 @@ public class TestOptimizers extends BaseDL4JTest {

 @Override
 public INDArray input() {
-//Work-around for BaseUpdater.postApply(): Uses Layer.input().size(0)
+//Work-around for BaseUpdater.postApply(): Uses ILayer.input().size(0)
 //in order to get mini-batch size. i.e., divide by 1 here.
 return Nd4j.zeros(1);
 }
@ -221,7 +221,7 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = net.getConfiguration();
|
ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
|
||||||
assertEquals(3, conf.getVertices().size());
|
assertEquals(3, conf.getVertices().size());
|
||||||
|
|
||||||
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
||||||
|
|
|
@ -221,7 +221,7 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = net.getConfiguration();
|
ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
|
||||||
assertEquals(3, conf.getVertices().size());
|
assertEquals(3, conf.getVertices().size());
|
||||||
|
|
||||||
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
||||||
|
|
|
@ -237,7 +237,7 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true);
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = net.getConfiguration();
|
ComputationGraphConfiguration conf = net.getComputationGraphConfiguration();
|
||||||
assertEquals(3, conf.getVertices().size());
|
assertEquals(3, conf.getVertices().size());
|
||||||
|
|
||||||
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer();
|
||||||
|
|
|
@ -171,7 +171,7 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
int nBoxes = 5;
|
int nBoxes = 5;
|
||||||
int nClasses = 10;
|
int nClasses = 10;
|
||||||
|
|
||||||
ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer();
|
ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer();
|
||||||
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
||||||
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
||||||
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
|
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
|
||||||
|
|
|
@ -206,7 +206,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
int nBoxes = 5;
|
int nBoxes = 5;
|
||||||
int nClasses = 10;
|
int nClasses = 10;
|
||||||
|
|
||||||
ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer();
|
ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer();
|
||||||
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
||||||
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
||||||
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
|
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
|
||||||
|
|
|
@ -224,7 +224,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
int nBoxes = 5;
|
int nBoxes = 5;
|
||||||
int nClasses = 10;
|
int nClasses = 10;
|
||||||
|
|
||||||
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
|
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices()
|
||||||
.get("convolution2d_9")).getLayerConf().getLayer();
|
.get("convolution2d_9")).getLayerConf().getLayer();
|
||||||
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
||||||
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
||||||
|
|
|
@ -205,7 +205,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
int nBoxes = 5;
|
int nBoxes = 5;
|
||||||
int nClasses = 10;
|
int nClasses = 10;
|
||||||
|
|
||||||
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices()
|
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices()
|
||||||
.get("convolution2d_9")).getLayerConf().getLayer();
|
.get("convolution2d_9")).getLayerConf().getLayer();
|
||||||
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
||||||
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
||||||
|
|
|
@ -94,7 +94,7 @@ public class CustomLayer extends FeedForwardLayer {
|
||||||
@Override
|
@Override
|
||||||
public ParamInitializer initializer() {
|
public ParamInitializer initializer() {
|
||||||
//This method returns the parameter initializer for this type of layer
|
//This method returns the parameter initializer for this type of layer
|
||||||
//In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer
|
//In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayerConfiguration
|
||||||
//For more complex layers, you may need to implement a custom parameter initializer
|
//For more complex layers, you may need to implement a custom parameter initializer
|
||||||
//See the various parameter initializers here:
|
//See the various parameter initializers here:
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params
|
//https://github.com/deeplearning4j/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params
|
||||||
|
@ -108,7 +108,7 @@ public class CustomLayer extends FeedForwardLayer {
|
||||||
//If you don't need this functionality for your custom layer, you can return a LayerMemoryReport
|
//If you don't need this functionality for your custom layer, you can return a LayerMemoryReport
|
||||||
// with all 0s, or
|
// with all 0s, or
|
||||||
|
|
||||||
//This implementation: based on DenseLayer implementation
|
//This implementation: based on DenseLayerConfiguration implementation
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
val numParams = initializer().numParams(this);
|
||||||
|
@ -131,7 +131,7 @@ public class CustomLayer extends FeedForwardLayer {
|
||||||
.workingMemory(0, 0, trainSizeFixed,
|
.workingMemory(0, 0, trainSizeFixed,
|
||||||
trainSizeVariable) //No additional memory (beyond activations) for inference
|
trainSizeVariable) //No additional memory (beyond activations) for inference
|
||||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS,
|
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS,
|
||||||
MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer
|
MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,7 +117,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
String str = FileUtils.readFileToString(list[0]);
|
String str = FileUtils.readFileToString(list[0]);
|
||||||
// System.out.println(str);
|
// System.out.println(str);
|
||||||
assertTrue(str.contains("Network Information"));
|
assertTrue(str.contains("Network Information"));
|
||||||
assertTrue(str.contains("Layer Helpers"));
|
assertTrue(str.contains("ILayer Helpers"));
|
||||||
assertTrue(str.contains("JavaCPP"));
|
assertTrue(str.contains("JavaCPP"));
|
||||||
assertTrue(str.contains("ScoreIterationListener"));
|
assertTrue(str.contains("ScoreIterationListener"));
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
assertEquals(1, list.length);
|
assertEquals(1, list.length);
|
||||||
str = FileUtils.readFileToString(list[0]);
|
str = FileUtils.readFileToString(list[0]);
|
||||||
assertTrue(str.contains("Network Information"));
|
assertTrue(str.contains("Network Information"));
|
||||||
assertTrue(str.contains("Layer Helpers"));
|
assertTrue(str.contains("ILayer Helpers"));
|
||||||
assertTrue(str.contains("JavaCPP"));
|
assertTrue(str.contains("JavaCPP"));
|
||||||
assertTrue(str.contains("ScoreIterationListener(1)"));
|
assertTrue(str.contains("ScoreIterationListener(1)"));
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
// System.out.println("///////////////////////////////////////////////////////////");
|
// System.out.println("///////////////////////////////////////////////////////////");
|
||||||
|
|
||||||
assertTrue(mlnMemoryInfo.contains("Network Information"));
|
assertTrue(mlnMemoryInfo.contains("Network Information"));
|
||||||
assertTrue(mlnMemoryInfo.contains("Layer Helpers"));
|
assertTrue(mlnMemoryInfo.contains("ILayer Helpers"));
|
||||||
assertTrue(mlnMemoryInfo.contains("JavaCPP"));
|
assertTrue(mlnMemoryInfo.contains("JavaCPP"));
|
||||||
assertTrue(mlnMemoryInfo.contains("ScoreIterationListener(1)"));
|
assertTrue(mlnMemoryInfo.contains("ScoreIterationListener(1)"));
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
assertEquals(1, list.length);
|
assertEquals(1, list.length);
|
||||||
str = FileUtils.readFileToString(list[0]);
|
str = FileUtils.readFileToString(list[0]);
|
||||||
assertTrue(str.contains("Network Information"));
|
assertTrue(str.contains("Network Information"));
|
||||||
assertTrue(str.contains("Layer Helpers"));
|
assertTrue(str.contains("ILayer Helpers"));
|
||||||
assertTrue(str.contains("JavaCPP"));
|
assertTrue(str.contains("JavaCPP"));
|
||||||
assertTrue(str.contains("ScoreIterationListener(1)"));
|
assertTrue(str.contains("ScoreIterationListener(1)"));
|
||||||
|
|
||||||
|
@ -187,7 +187,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
assertEquals(1, list.length);
|
assertEquals(1, list.length);
|
||||||
str = FileUtils.readFileToString(list[0]);
|
str = FileUtils.readFileToString(list[0]);
|
||||||
assertTrue(str.contains("Network Information"));
|
assertTrue(str.contains("Network Information"));
|
||||||
assertTrue(str.contains("Layer Helpers"));
|
assertTrue(str.contains("ILayer Helpers"));
|
||||||
assertTrue(str.contains("JavaCPP"));
|
assertTrue(str.contains("JavaCPP"));
|
||||||
assertTrue(str.contains("ScoreIterationListener(1)"));
|
assertTrue(str.contains("ScoreIterationListener(1)"));
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest {
|
||||||
// System.out.println("///////////////////////////////////////////////////////////");
|
// System.out.println("///////////////////////////////////////////////////////////");
|
||||||
|
|
||||||
assertTrue(cgMemoryInfo.contains("Network Information"));
|
assertTrue(cgMemoryInfo.contains("Network Information"));
|
||||||
assertTrue(cgMemoryInfo.contains("Layer Helpers"));
|
assertTrue(cgMemoryInfo.contains("ILayer Helpers"));
|
||||||
assertTrue(cgMemoryInfo.contains("JavaCPP"));
|
assertTrue(cgMemoryInfo.contains("JavaCPP"));
|
||||||
assertTrue(cgMemoryInfo.contains("ScoreIterationListener(1)"));
|
assertTrue(cgMemoryInfo.contains("ScoreIterationListener(1)"));
|
||||||
|
|
||||||
|
|
|
@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);
|
ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);
|
||||||
|
|
||||||
assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson());
|
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
|
||||||
assertEquals(cg.params(), network.params());
|
assertEquals(cg.params(), network.params());
|
||||||
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
||||||
}
|
}
|
||||||
|
@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
|
ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
|
||||||
|
|
||||||
assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson());
|
assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson());
|
||||||
assertEquals(cg.params(), network.params());
|
assertEquals(cg.params(), network.params());
|
||||||
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,7 +198,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
|
||||||
}
|
}
|
||||||
if (!(activationFn instanceof ActivationTanH)) {
|
if (!(activationFn instanceof ActivationTanH)) {
|
||||||
supported = false;
|
supported = false;
|
||||||
log.warn("Not supported: Layer activation functions != ActivationTanH");
|
log.warn("Not supported: ILayer activation functions != ActivationTanH");
|
||||||
}
|
}
|
||||||
if (hasPeepholeConnections) {
|
if (hasPeepholeConnections) {
|
||||||
supported = false;
|
supported = false;
|
||||||
|
|
|
@ -295,7 +295,7 @@ public class KerasLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copy Keras layer weights to DL4J Layer.
|
* Copy Keras layer weights to DL4J ILayer.
|
||||||
*
|
*
|
||||||
* @param layer DL4J layer
|
* @param layer DL4J layer
|
||||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||||
|
@ -358,7 +358,7 @@ public class KerasLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Whether this Keras layer maps to a DL4J Layer.
|
* Whether this Keras layer maps to a DL4J ILayer.
|
||||||
*
|
*
|
||||||
* @return true or false
|
* @return true or false
|
||||||
*/
|
*/
|
||||||
|
@ -367,9 +367,9 @@ public class KerasLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets corresponding DL4J Layer, if any.
|
* Gets corresponding DL4J ILayer, if any.
|
||||||
*
|
*
|
||||||
* @return DL4J Layer
|
* @return DL4J ILayer
|
||||||
* @see org.deeplearning4j.nn.api.Layer
|
* @see org.deeplearning4j.nn.api.Layer
|
||||||
*/
|
*/
|
||||||
public Layer getLayer() {
|
public Layer getLayer() {
|
||||||
|
|
|
@ -583,8 +583,8 @@ public class KerasModel {
|
||||||
graphBuilder.addVertex(layer.getLayerName(), layer.getVertex(), inboundLayerNamesArray);
|
graphBuilder.addVertex(layer.getLayerName(), layer.getVertex(), inboundLayerNamesArray);
|
||||||
} else if (layer.isInputPreProcessor()) {
|
} else if (layer.isInputPreProcessor()) {
|
||||||
if (preprocessor == null)
|
if (preprocessor == null)
|
||||||
throw new UnsupportedKerasConfigurationException("Layer " + layer.getLayerName()
|
throw new UnsupportedKerasConfigurationException("ILayer " + layer.getLayerName()
|
||||||
+ " could not be mapped to Layer, Vertex, or InputPreProcessor");
|
+ " could not be mapped to ILayer, Vertex, or InputPreProcessor");
|
||||||
graphBuilder.addVertex(layer.getLayerName(), new PreprocessorVertex(preprocessor),
|
graphBuilder.addVertex(layer.getLayerName(), new PreprocessorVertex(preprocessor),
|
||||||
inboundLayerNamesArray);
|
inboundLayerNamesArray);
|
||||||
}
|
}
|
||||||
|
|
|
@ -246,7 +246,7 @@ public class KerasLayerConfiguration {
|
||||||
private final String LAYER_FIELD_RATE = "rate";
|
private final String LAYER_FIELD_RATE = "rate";
|
||||||
private final String LAYER_FIELD_GAUSSIAN_VARIANCE = ""; // 1: sigma, 2: stddev
|
private final String LAYER_FIELD_GAUSSIAN_VARIANCE = ""; // 1: sigma, 2: stddev
|
||||||
|
|
||||||
/* Layer wrappers */
|
/* ILayer wrappers */
|
||||||
// Missing: TimeDistributed
|
// Missing: TimeDistributed
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -115,9 +115,9 @@ public class KerasDense extends KerasLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get DL4J DenseLayer.
|
* Get DL4J DenseLayerConfiguration.
|
||||||
*
|
*
|
||||||
* @return DenseLayer
|
* @return DenseLayerConfiguration
|
||||||
*/
|
*/
|
||||||
public DenseLayer getDenseLayer() {
|
public DenseLayer getDenseLayer() {
|
||||||
return (DenseLayer) this.layer;
|
return (DenseLayer) this.layer;
|
||||||
|
|
|
@ -211,10 +211,10 @@ public class KerasLSTM extends KerasLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get DL4J Layer. If returnSequences is true, this can be casted to an "LSTM" layer, otherwise it can be casted
|
* Get DL4J ILayer. If returnSequences is true, this can be cast to an "LSTM" layer, otherwise it can be cast
|
||||||
* to a "LastTimeStep" layer.
|
* to a "LastTimeStep" layer.
|
||||||
*
|
*
|
||||||
* @return LSTM Layer
|
* @return LSTM ILayer
|
||||||
*/
|
*/
|
||||||
public Layer getLSTMLayer() {
|
public Layer getLSTMLayer() {
|
||||||
return layer;
|
return layer;
|
||||||
|
|
|
@ -184,7 +184,7 @@ public class KerasSimpleRnn extends KerasLayer {
|
||||||
/**
|
/**
|
||||||
* Get DL4J SimpleRnn layer.
|
* Get DL4J SimpleRnn layer.
|
||||||
*
|
*
|
||||||
* @return SimpleRnn Layer
|
* @return SimpleRnn ILayer
|
||||||
*/
|
*/
|
||||||
public Layer getSimpleRnnLayer() {
|
public Layer getSimpleRnnLayer() {
|
||||||
return this.layer;
|
return this.layer;
|
||||||
|
|
|
@ -160,7 +160,7 @@ public class KerasBidirectional extends KerasLayer {
|
||||||
/**
|
/**
|
||||||
* Return the underlying recurrent layer of this bidirectional layer
|
* Return the underlying recurrent layer of this bidirectional layer
|
||||||
*
|
*
|
||||||
* @return Layer, recurrent layer
|
* @return ILayer, recurrent layer
|
||||||
*/
|
*/
|
||||||
public Layer getUnderlyingRecurrentLayer() {
|
public Layer getUnderlyingRecurrentLayer() {
|
||||||
return kerasRnnlayer.getLayer();
|
return kerasRnnlayer.getLayer();
|
||||||
|
@ -169,7 +169,7 @@ public class KerasBidirectional extends KerasLayer {
|
||||||
/**
|
/**
|
||||||
* Get DL4J Bidirectional layer.
|
* Get DL4J Bidirectional layer.
|
||||||
*
|
*
|
||||||
* @return Bidirectional Layer
|
* @return Bidirectional ILayer
|
||||||
*/
|
*/
|
||||||
public Bidirectional getBidirectionalLayer() {
|
public Bidirectional getBidirectionalLayer() {
|
||||||
return (Bidirectional) this.layer;
|
return (Bidirectional) this.layer;
|
||||||
|
|
|
@ -85,7 +85,7 @@ public class FullModelComparisons extends BaseDL4JTest {
|
||||||
|
|
||||||
System.out.println(model.summary());
|
System.out.println(model.summary());
|
||||||
|
|
||||||
// 1. Layer
|
// 1. ILayer
|
||||||
LSTM firstLstm = (LSTM) model.getLayer(0);
|
LSTM firstLstm = (LSTM) model.getLayer(0);
|
||||||
org.deeplearning4j.nn.conf.layers.LSTM firstConf =
|
org.deeplearning4j.nn.conf.layers.LSTM firstConf =
|
||||||
(org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.conf().getLayer();
|
(org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.conf().getLayer();
|
||||||
|
@ -123,7 +123,7 @@ public class FullModelComparisons extends BaseDL4JTest {
|
||||||
Assertions.assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O
|
Assertions.assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O
|
||||||
Assertions.assertEquals(b.getDouble(0, 0), -0.2587392, 1e-7); // Keras C
|
Assertions.assertEquals(b.getDouble(0, 0), -0.2587392, 1e-7); // Keras C
|
||||||
|
|
||||||
// 2. Layer
|
// 2. ILayer
|
||||||
LSTM secondLstm = (LSTM) ((LastTimeStepLayer) model.getLayer(1)).getUnderlying();
|
LSTM secondLstm = (LSTM) ((LastTimeStepLayer) model.getLayer(1)).getUnderlying();
|
||||||
org.deeplearning4j.nn.conf.layers.LSTM secondConf =
|
org.deeplearning4j.nn.conf.layers.LSTM secondConf =
|
||||||
(org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.conf().getLayer();
|
(org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.conf().getLayer();
|
||||||
|
|
|
@ -39,4 +39,13 @@ public interface LayerConfiguration {
|
||||||
*/
|
*/
|
||||||
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
|
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines the valid output type for this Layer
|
||||||
|
*
|
||||||
|
* @return InputType
|
||||||
|
*/
|
||||||
|
org.deeplearning4j.nn.conf.inputs.InputType.Type getOutputType();
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation platform(projects.cavisCommonPlatform)
|
implementation platform(projects.cavisCommonPlatform)
|
||||||
implementation projects.cavisDnn.cavisDnnNnApi
|
// implementation projects.cavisDnn.cavisDnnNnApi
|
||||||
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
||||||
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
||||||
implementation projects.cavisDnn.cavisDnnCommon
|
implementation projects.cavisDnn.cavisDnnCommon
|
||||||
|
@ -58,3 +58,5 @@ dependencies {
|
||||||
implementation "com.squareup.okhttp3:okhttp"
|
implementation "com.squareup.okhttp3:okhttp"
|
||||||
implementation "com.squareup.okhttp3:logging-interceptor"
|
implementation "com.squareup.okhttp3:logging-interceptor"
|
||||||
}
|
}
|
||||||
|
sourceCompatibility = JavaVersion.VERSION_11
|
||||||
|
targetCompatibility = JavaVersion.VERSION_11
|
||||||
|
|
|
@ -19,10 +19,28 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package net.brutex.ai.dnn.conf.layer;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
public abstract class LayerConfiguration {
|
/**
|
||||||
|
* This is an "executable" ILayer, that is based on a {@link ILayerConfiguration}
|
||||||
|
*/
|
||||||
|
public interface ILayer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the underlying configuration for this ILayer
|
||||||
|
* @return configuration
|
||||||
|
*/
|
||||||
|
ILayerConfiguration getLayerConfiguration();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the underlying layer configuration
|
||||||
|
* @param conf The new configuration
|
||||||
|
*/
|
||||||
|
void setLayerConfiguration(ILayerConfiguration conf);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An implementation should provide a method to validate the network
|
||||||
|
* @return true if no errors found; false otherwise
|
||||||
|
*/
|
||||||
|
boolean isValid();
|
||||||
}
|
}
|
|
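To make the intent of the new ILayer interface concrete, here is a minimal, hedged implementation sketch; the class name SimpleLayer is an assumption made for illustration and is not part of this change:

package net.brutex.ai.dnn.api;

// Hypothetical sketch only: a trivial ILayer that just carries its configuration.
public class SimpleLayer implements ILayer {

  private ILayerConfiguration configuration;

  @Override
  public ILayerConfiguration getLayerConfiguration() {
    return configuration;
  }

  @Override
  public void setLayerConfiguration(ILayerConfiguration conf) {
    this.configuration = conf;
  }

  @Override
  public boolean isValid() {
    // A layer without a (valid) configuration cannot be considered valid.
    return configuration != null && configuration.isValid();
  }
}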
@ -19,34 +19,45 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package net.brutex.ai.dnn.conf.layer;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.brutex.ai.dnn.api.Layer;
|
|
||||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
|
||||||
import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FFLayer extends AbstractLayerConfiguration {
|
|
||||||
|
|
||||||
|
public interface ILayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create and return an instance of a LayerConfiguration.
|
* Create and return an {@link ILayer} instance based on this ILayerConfiguration.
|
||||||
*
|
*
|
||||||
* @param network the "holding" network for the instance
|
* @param network the "holding" network for the instance
|
||||||
* @return the new layer instance
|
* @return the new layer instance
|
||||||
*/
|
*/
|
||||||
@Override
|
ILayer instantiate(IModel network);
|
||||||
public Layer instantiate(NeuralNetwork network) {
|
|
||||||
//Let's do some verifications first
|
|
||||||
if(getInputType() != Type.FF) {
|
/**
|
||||||
log.error("The {} layer configuration must use an InputType of {}, but found {}",
|
* Defines the valid input type for this ILayer
|
||||||
this.getClass().getSimpleName(),
|
*
|
||||||
Type.FF.name(),
|
* @return InputType
|
||||||
getInputType().name());
|
*/
|
||||||
}
|
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
/**
|
||||||
|
* Defines the valid output type for this ILayer
|
||||||
|
*
|
||||||
|
* @return InputType
|
||||||
|
*/
|
||||||
|
org.deeplearning4j.nn.conf.inputs.InputType.Type getOutputType();
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of trainable parameters in this layer
|
||||||
|
* @return number of parameters
|
||||||
|
*/
|
||||||
|
long numParameters();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An implementation should provide a method to validate the network
|
||||||
|
* @return true if no errors found; false otherwise
|
||||||
|
*/
|
||||||
|
boolean isValid();
|
||||||
|
|
||||||
}
|
}
|
|
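As a rough illustration of the ILayerConfiguration contract above, the following hedged sketch shows what a dense-layer style configuration could look like; the class name, the nIn/nOut fields and the parameter-count formula are assumptions made only for this example:

package net.brutex.ai.dnn.api;

import org.deeplearning4j.nn.conf.inputs.InputType;

// Hypothetical sketch only: a feed-forward layer configuration with nIn inputs and nOut outputs.
public class DenseLayerConfigurationSketch implements ILayerConfiguration {

  private final long nIn;
  private final long nOut;

  public DenseLayerConfigurationSketch(long nIn, long nOut) {
    this.nIn = nIn;
    this.nOut = nOut;
  }

  @Override
  public ILayer instantiate(IModel network) {
    // A real implementation would construct the runtime layer for the given network here.
    throw new UnsupportedOperationException("illustrative sketch only");
  }

  @Override
  public InputType.Type getInputType() {
    return InputType.Type.FF;
  }

  @Override
  public InputType.Type getOutputType() {
    return InputType.Type.FF;
  }

  @Override
  public long numParameters() {
    // Weight matrix plus bias vector of a dense layer.
    return nIn * nOut + nOut;
  }

  @Override
  public boolean isValid() {
    return nIn > 0 && nOut > 0;
  }
}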
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A neural network is an instance created from an {@link INeuralNetworkConfiguration} that can be trained,
|
||||||
|
* evaluated, saved, exported, etc. Its configuration state is defined with the
|
||||||
|
* {@link #setConfiguration(INeuralNetworkConfiguration)} and {@link #getConfiguration()} methods.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public interface IModel {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The configuration that defines this Neural Network
|
||||||
|
*
|
||||||
|
* @param conf the configuration to use for this network
|
||||||
|
*/
|
||||||
|
void setConfiguration(INeuralNetworkConfiguration conf);
|
||||||
|
INeuralNetworkConfiguration getConfiguration();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fit the model for one iteration on the provided data
|
||||||
|
*
|
||||||
|
* @param features the examples to classify (one example in each row)
|
||||||
|
* @param labels the example labels (a binary outcome matrix)
|
||||||
|
* @param featuresMask The mask array for the features (used for variable length time series, etc). May be null.
|
||||||
|
* @param labelsMask The mask array for the labels (used for variable length time series, etc). May be null.
|
||||||
|
*/
|
||||||
|
void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method fits model with a given DataSet
|
||||||
|
*
|
||||||
|
* @param dataSet the dataset to use for training
|
||||||
|
*/
|
||||||
|
void fit(DataSet dataSet);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method fits model with a given MultiDataSet
|
||||||
|
*
|
||||||
|
* @param dataSet the multi dataset to use for training
|
||||||
|
*/
|
||||||
|
void fit(MultiDataSet dataSet);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The name of the Neural Network
|
||||||
|
* @return the name
|
||||||
|
*/
|
||||||
|
String getName();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the name for this Neural Network
|
||||||
|
* @param name the name
|
||||||
|
*/
|
||||||
|
void setName(String name);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An implementation should provide a method to validate the network
|
||||||
|
* @return true if no errors found; false otherwise
|
||||||
|
*/
|
||||||
|
boolean isValid();
|
||||||
|
|
||||||
|
}
|
|
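A short, hypothetical usage sketch of the IModel contract above; the ModelTrainingSketch helper and its epoch loop are illustrative assumptions, not provided by this change:

import org.nd4j.linalg.dataset.api.DataSet;

// Hypothetical helper showing how a caller could use the IModel contract.
public final class ModelTrainingSketch {

  public static void train(IModel model, DataSet data, int epochs) {
    if (!model.isValid()) {
      throw new IllegalStateException("Model " + model.getName() + " failed validation");
    }
    for (int i = 0; i < epochs; i++) {
      model.fit(data); // one pass over the DataSet per epoch in this simple sketch
    }
  }
}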
@ -1,7 +1,7 @@
|
||||||
/*
|
/*
|
||||||
|
*
|
||||||
* ******************************************************************************
|
* ******************************************************************************
|
||||||
* *
|
* *
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
* * This program and the accompanying materials are made available under the
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
@ -16,10 +16,12 @@
|
||||||
* *
|
* *
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
|
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||||
import org.deeplearning4j.optimize.api.ConvexOptimizer;
|
import org.deeplearning4j.optimize.api.ConvexOptimizer;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -31,7 +33,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
/**
|
/**
|
||||||
* @author raver119
|
* @author raver119
|
||||||
*/
|
*/
|
||||||
public interface NeuralNetwork {
|
public interface INeuralNetwork {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method does initialization of model
|
* This method does initialization of model
|
||||||
|
@ -104,4 +106,17 @@ public interface NeuralNetwork {
|
||||||
* @param iterator
|
* @param iterator
|
||||||
*/
|
*/
|
||||||
<T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations);
|
<T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A neural network is created from a configuration.
|
||||||
|
* @param conf the configuration to create the network from
|
||||||
|
*/
|
||||||
|
void setConfiguration(NeuralNetworkConfiguration conf);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the configuration for this network
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
NeuralNetworkConfiguration getConfiguration();
|
||||||
|
|
||||||
}
|
}
|
|
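The added setConfiguration/getConfiguration methods tie a network to the configuration it was built from; the following hedged snippet illustrates the expected round trip (the ConfigurationRoundTripSketch helper is hypothetical):

import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;

// Hypothetical check: the configuration handed to a network should be retrievable again.
public final class ConfigurationRoundTripSketch {

  public static void attach(INeuralNetwork net, NeuralNetworkConfiguration conf) {
    net.setConfiguration(conf);
    if (net.getConfiguration() != conf) {
      throw new IllegalStateException("Network did not retain its configuration");
    }
  }
}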
@ -0,0 +1,52 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface INeuralNetworkConfiguration {
|
||||||
|
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
/**
|
||||||
|
* Provides a flat list of all embedded layer configurations, this
|
||||||
|
* can only be called after the layer is initialized or {@link #getLayerConfigurations()} is
|
||||||
|
* called.
|
||||||
|
*
|
||||||
|
* @return unstacked layer configurations
|
||||||
|
|
||||||
|
List<ILayerConfiguration> getLayerConfigurations();
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This flattens any stacked layer configurations within building blocks like
|
||||||
|
* {@link BuildingBlockLayer}
|
||||||
|
|
||||||
|
void calculateInnerLayerConfigurations();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An implementation should provide a method to validate the network
|
||||||
|
* @return true if no errors found; false otherwise
|
||||||
|
|
||||||
|
boolean isValid();
|
||||||
|
}
|
||||||
|
**/
|
|
@ -22,32 +22,61 @@
|
||||||
package net.brutex.ai.dnn.conf;
|
package net.brutex.ai.dnn.conf;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
||||||
|
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
|
import java.io.IOException;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Random;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.Singular;
|
import lombok.Singular;
|
||||||
import lombok.extern.jackson.Jacksonized;
|
import lombok.extern.jackson.Jacksonized;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
import net.brutex.ai.dnn.api.ILayerConfiguration;
|
||||||
|
import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.BackpropType;
|
import org.deeplearning4j.nn.conf.BackpropType;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The NeuralNetworkConfiguration is a sequential container for the different layers in your
|
* The INeuralNetworkConfiguration is a sequential container for the different layers in your
|
||||||
* network (or other NeuralNetworkConfigurations). That said, NeuralNetworkConfigurations can be
|
* network (or other NeuralNetworkConfigurations). That said, NeuralNetworkConfigurations can be
|
||||||
* stacked.<br/><br/>
|
* stacked.<br/><br/>
|
||||||
* It then “chains” outputs to inputs sequentially for each NeuralNetworkConfiguration,
|
* It then “chains” outputs to inputs sequentially for each INeuralNetworkConfiguration,
|
||||||
* finally returning the output of the "top" configuration. Any settings made, are inherited and can
|
* finally returning the output of the "top" configuration. Any settings made, are inherited and can
|
||||||
* be overridden on a "deeper" level. For this use case, you need to wrap the NeuralNetworkConfiguration
|
* be overridden on a "deeper" level. For this use case, you need to wrap the INeuralNetworkConfiguration
|
||||||
* into a BuildingBlockLayer
|
* into a BuildingBlockLayer
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
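Based only on the builder fields visible in this patch (name, seed, cacheMode, backpropType, tbpttFwdLength, tbpttBackLength), a hedged example of how such a configuration might be assembled; other required fields may exist, so treat this as a sketch rather than a guaranteed-buildable call:

import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;

public final class ConfigurationBuilderSketch {

  // Hypothetical builder usage for a recurrent network trained with truncated BPTT.
  public static NeuralNetworkConfiguration buildExample() {
    return NeuralNetworkConfiguration.builder()
        .name("rnn-example")
        .seed(42L)
        .cacheMode(CacheMode.NONE)
        .backpropType(BackpropType.TruncatedBPTT)
        .tbpttFwdLength(20)   // k1: forward steps per truncated-BPTT segment
        .tbpttBackLength(20)  // k2: backward steps, at most k1
        .build();
  }
}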
|
@ -55,77 +84,54 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
||||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
@lombok.Builder
|
@lombok.Builder
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable {
|
public class NeuralNetworkConfiguration extends NeuralNetConfiguration implements
|
||||||
|
INeuralNetworkConfiguration, Serializable, Cloneable {
|
||||||
/**
|
|
||||||
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
|
|
||||||
* Valid values are<br/>
|
|
||||||
* CacheMode.NONE,<br/>
|
|
||||||
* CacheMode.HOST or<br/>
|
|
||||||
* CacheMode.DEVICE<br/>
|
|
||||||
*/
|
|
||||||
@NonNull
|
|
||||||
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
|
||||||
|
|
||||||
@Getter @Setter @NonNull
|
|
||||||
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
|
||||||
|
|
||||||
@Getter @Setter @NonNull
|
|
||||||
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
|
||||||
|
|
||||||
@Getter @Setter @NonNull
|
|
||||||
protected BackpropType backpropType = BackpropType.Standard;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
|
||||||
|
|
||||||
|
|
||||||
@Getter @Setter protected int tbpttFwdLength = 20;
|
|
||||||
@Getter @Setter protected int tbpttBackLength = 20;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The list of layer configurations in this configuration. They will be indexed automatically
|
|
||||||
* as the layers get added starting with index 0.
|
|
||||||
*/
|
|
||||||
@Singular @Getter
|
|
||||||
private List<LayerConfiguration> layerConfigurations;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if
|
|
||||||
* it is not specified.
|
|
||||||
*/
|
|
||||||
@lombok.Builder.Default @Getter
|
|
||||||
private String name = "Anonymous NeuralNetworkConfiguration";
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The {@link InputType} of the data for this network configuration
|
|
||||||
*/
|
|
||||||
private InputType inputType;
|
|
||||||
|
|
||||||
|
private static final int DEFAULT_TBPTT_LENGTH = 20;
|
||||||
|
@Getter protected final List<NeuralNetworkConfiguration> confs = new ArrayList<>();
|
||||||
/**
|
/**
|
||||||
* hidden list of layers, that "flattens" all the layers of this network and applies
|
* hidden list of layers, that "flattens" all the layers of this network and applies
|
||||||
* inheritance.
|
* inheritance.
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
|
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
|
||||||
private final List<LayerConfiguration> innerLayerConfigurations;
|
private final List<ILayerConfiguration> innerLayerConfigurations;
|
||||||
|
@Getter @Setter @NonNull @Singular
|
||||||
@Override
|
protected List<Layer> layers = new ArrayList<>();
|
||||||
public void calculateInnerLayerConfigurations() {
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
List<LayerConfiguration> list = new ArrayList<>();
|
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
for( LayerConfiguration layer : this.layerConfigurations) {
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
if(layer instanceof BuildingBlockLayer) {
|
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
BuildingBlockLayer blayer = (BuildingBlockLayer) layer;
|
/**
|
||||||
blayer.getConf().calculateInnerLayerConfigurations();
|
* The type of backprop. Default setting is used for most networks (MLP, CNN etc), but
|
||||||
list.addAll(blayer.getConf().getLayerConfigurations());
|
* optionally truncated BPTT can be used for training recurrent neural networks. If using
|
||||||
} else {
|
* TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
|
||||||
list.add(layer);
|
*/
|
||||||
}
|
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||||
}
|
protected BackpropType backpropType = BackpropType.Standard;
|
||||||
this.layerConfigurations = list;
|
@Getter
|
||||||
}
|
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||||
|
/**
|
||||||
|
* When doing truncated BPTT: how many steps of forward pass should we do before doing
|
||||||
|
* (truncated) backprop?<br> Only applicable when doing
|
||||||
|
* backpropType(BackpropType.TruncatedBPTT)<br> Typically tBPTTForwardLength parameter is same
|
||||||
|
* as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but
|
||||||
|
* never smaller)<br> Ideally your training data time series length should be divisible by this.
|
||||||
|
* This is the k1 parameter on pg23 of
|
||||||
|
* <a
|
||||||
|
* href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
|
||||||
|
*
|
||||||
|
* @param forwardLength Forward length > 0, >= backwardLength
|
||||||
|
*/
|
||||||
|
@Getter @Setter protected int tbpttFwdLength = 20;
|
||||||
|
/**
|
||||||
|
* When doing truncated BPTT: how many steps of backward should we do?<br> Only applicable when
|
||||||
|
* doing backpropType(BackpropType.TruncatedBPTT)<br> This is the k2 parameter on pg23 of
|
||||||
|
* <a
|
||||||
|
* href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
|
||||||
|
*
|
||||||
|
* @param backwardLength <= forwardLength
|
||||||
|
*/
|
||||||
|
@Getter @Setter protected int tbpttBackLength = 20;
|
||||||
/**
|
/**
|
||||||
* Creates and returns a copy of this object.
|
* Creates and returns a copy of this object.
|
||||||
*
|
*
|
||||||
|
@ -136,8 +142,564 @@ public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralN
|
||||||
* cannot be cloned.
|
* cannot be cloned.
|
||||||
* @see Cloneable
|
* @see Cloneable
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
//Nd4j.getRandom().setSeed(getConf(0).getSeed()); //TODO
|
||||||
|
//Counter for the number of parameter updates so far
|
||||||
|
// This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
|
||||||
|
// for Spark and model serialization
|
||||||
|
@Getter @Setter
|
||||||
|
protected int iterationCount = 0;
|
||||||
|
//Counter for the number of epochs completed so far. Used for per-epoch schedules
|
||||||
|
@Getter @Setter
|
||||||
|
protected int epochCount = 0;
|
||||||
|
protected double dampingFactor = 100;
|
||||||
|
@Getter @Setter //todo why?
|
||||||
|
private Layer layer;
|
||||||
|
/**
|
||||||
|
* A seed for this network, will be random if not specified.
|
||||||
|
*/
|
||||||
|
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||||
|
private long seed = new Random().nextLong();
|
||||||
|
/**
|
||||||
|
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
|
||||||
|
* This method defines how/if preOutput cache is handled: NONE: cache disabled (default value)
|
||||||
|
* HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will
|
||||||
|
* be the same as for HOST)
|
||||||
|
*
|
||||||
|
* Valid values are<br/>
|
||||||
|
* CacheMode.NONE,<br/>
|
||||||
|
* CacheMode.HOST or<br/>
|
||||||
|
* CacheMode.DEVICE<br/>
|
||||||
|
* @param cacheMode
|
||||||
|
*/
|
||||||
|
@NonNull @Getter @Setter
|
||||||
|
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
||||||
|
/**
|
||||||
|
* The list of layer configurations in this configuration. They will be indexed automatically
|
||||||
|
* as the layers get added starting with index 0.
|
||||||
|
*/
|
||||||
|
@Singular @Getter
|
||||||
|
private List<ILayerConfiguration> layerConfigurations;
|
||||||
|
/**
|
||||||
|
* The name for this configuration. Defaults to "Anonymous INeuralNetworkConfiguration" if
|
||||||
|
* it is not specified.
|
||||||
|
*/
|
||||||
|
@lombok.Builder.Default @Getter
|
||||||
|
private String name = "Anonymous INeuralNetworkConfiguration";
|
||||||
|
/**
|
||||||
|
* The {@link InputType} of the data for this network configuration
|
||||||
|
*/
|
||||||
|
private InputType inputType;
|
||||||
|
/**
|
||||||
|
* Set the DataType for the network parameters and activations for all layers in the network.
|
||||||
|
* Default: Float
|
||||||
|
*
|
||||||
|
* @param dataType Datatype to use for parameters and activations
|
||||||
|
*/
|
||||||
|
@Getter @Setter @lombok.Builder.Default @NonNull
|
||||||
|
private DataType dataType = DataType.FLOAT;
|
||||||
|
/**
|
||||||
|
* Whether to override the nIn configuration forcibly upon construction. Default value is true.
|
||||||
|
* @return builder pattern
|
||||||
|
*/
|
||||||
|
@Getter @Setter
|
||||||
|
@lombok.Builder.Default
|
||||||
|
private boolean overrideNinUponBuild = true;
|
||||||
|
/**
|
||||||
|
* Enabled by default. If enabled, the output layer configuration will be validated, to throw an
|
||||||
|
* exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.<br> If
|
||||||
|
* disabled (false) no output layer validation will be performed.<br> Disabling this validation
|
||||||
|
* is not recommended, as the configurations that fail validation usually will not be able to
|
||||||
|
* learn correctly. However, the option to disable this validation is provided for advanced
|
||||||
|
* users when creating non-standard architectures.
|
||||||
|
*
|
||||||
|
* @param validate If true: validate output layer configuration. False: don't validate
|
||||||
|
*/
|
||||||
|
@Getter @Setter @lombok.Builder.Default
|
||||||
|
private boolean validateOutputLayerConfig=true;
|
||||||
|
/**
|
||||||
|
* Enabled by default. If enabled, an exception will be thrown when using the (invalid)
|
||||||
|
* combination of truncated backpropagation through time (TBPTT) with either a
|
||||||
|
* GlobalPoolingLayer or LastTimeStepLayer.<br> It is possible to disable this validation to
|
||||||
|
* allow what is almost certainly an invalid configuration to be used, however this is not
|
||||||
|
* recommended.
|
||||||
|
*
|
||||||
|
* @param validate Whether TBPTT validation should be performed
|
||||||
|
*/
|
||||||
|
@Getter @Setter @lombok.Builder.Default
|
||||||
|
private boolean validateTbpttConfig=true;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam}
|
||||||
|
* or {@link org.nd4j.linalg.learning.config.Nesterovs}<br>
|
||||||
|
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
||||||
|
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
||||||
|
* value, and can be overridden on a per-layer basis.
|
||||||
|
*
|
||||||
|
* @param updater Updater to use
|
||||||
|
*/
|
||||||
|
@Getter @Setter @NonNull
|
||||||
|
private IUpdater updater;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
|
||||||
|
* See {@link GradientNormalization} for details<br>
|
||||||
|
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
||||||
|
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
||||||
|
* value, and can be overridden on a per-layer basis.
|
||||||
|
*
|
||||||
|
* @param gradientNormalization Type of normalization to use. Defaults to None.
|
||||||
|
* @see GradientNormalization
|
||||||
|
*/
|
||||||
|
@Getter @Setter @NonNull @lombok.Builder.Default
|
||||||
|
private GradientNormalization gradientNormalization = GradientNormalization.None;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
|
||||||
|
* GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue<br>
|
||||||
|
* Not used otherwise.<br>
|
||||||
|
* L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.<br>
|
||||||
|
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
||||||
|
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
||||||
|
* value, and can be overridden on a per-layer basis.
|
||||||
|
*/
|
||||||
|
@Getter @Setter
|
||||||
|
private double gradientNormalizationThreshold;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Weight initialization scheme to use, for initial weight values
|
||||||
|
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
||||||
|
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
||||||
|
* value, and can be overridden on a per-layer basis.
|
||||||
|
*/
|
||||||
|
@Getter @Setter
|
||||||
|
private IWeightInit weightInit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activation function / neuron non-linearity<br>
|
||||||
|
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
||||||
|
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
||||||
|
* value, and can be overridden on a per-layer basis.
|
||||||
|
*/
|
||||||
|
@Getter @Setter
|
||||||
|
private IActivation activation;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a neural net configuration from json
|
||||||
|
*
|
||||||
|
* @param json the neural net configuration from json
|
||||||
|
* @return {@link NeuralNetworkConfiguration}
|
||||||
|
*/
|
||||||
|
public static NeuralNetworkConfiguration fromJson(String json) {
|
||||||
|
NeuralNetworkConfiguration conf;
|
||||||
|
ObjectMapper mapper = NeuralNetworkConfiguration.mapper();
|
||||||
|
try {
|
||||||
|
conf = mapper.readValue(json, NeuralNetworkConfiguration.class);
|
||||||
|
} catch (InvalidTypeIdException e) {
|
||||||
|
if (e.getMessage().contains("@class")) {
|
||||||
|
try {
|
||||||
|
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
|
||||||
|
return JsonMappers.getLegacyMapper().readValue(json, NeuralNetworkConfiguration.class);
|
||||||
|
} catch (InvalidTypeIdException e2) {
|
||||||
|
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
|
||||||
|
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
|
||||||
|
String msg = e2.getMessage();
|
||||||
|
if (msg != null && msg.contains("Could not resolve type id")) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Error deserializing MultiLayerConfiguration - configuration may have a custom " +
|
||||||
|
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
|
||||||
|
+
|
||||||
|
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
|
||||||
|
e);
|
||||||
|
}
|
||||||
|
throw new RuntimeException(e2);
|
||||||
|
} catch (IOException e2) {
|
||||||
|
throw new RuntimeException(e2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
} catch (IOException e) {
|
||||||
|
//Check if this exception came from legacy deserializer...
|
||||||
|
String msg = e.getMessage();
|
||||||
|
if (msg != null && msg.contains("legacy")) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Error deserializing MultiLayerConfiguration - configuration may have a custom " +
|
||||||
|
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
|
||||||
|
+
|
||||||
|
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
|
||||||
|
e);
|
||||||
|
}
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
|
||||||
|
// Previously: enumeration used for loss functions. Now: use classes
|
||||||
|
// In the past, this could only have been an OutputLayer or RnnOutputLayer using these enums
|
||||||
|
int layerCount = 0;
|
||||||
|
JsonNode confs = null;
|
||||||
|
for (NeuralNetworkConfiguration nnc : conf.getConfs()) {
|
||||||
|
Layer l = nnc.getLayer();
|
||||||
|
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
|
||||||
|
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
|
||||||
|
//if so, try walking the JSON graph to extract out the appropriate enum value
|
||||||
|
|
||||||
|
BaseOutputLayer ol = (BaseOutputLayer) l;
|
||||||
|
try {
|
||||||
|
JsonNode jsonNode = mapper.readTree(json);
|
||||||
|
if (confs == null) {
|
||||||
|
confs = jsonNode.get("confs");
|
||||||
|
}
|
||||||
|
if (confs instanceof ArrayNode) {
|
||||||
|
ArrayNode layerConfs = (ArrayNode) confs;
|
||||||
|
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
||||||
|
if (outputLayerNNCNode == null) {
|
||||||
|
throw new RuntimeException("should never happen"); //return conf; //Should never happen...
|
||||||
|
}
|
||||||
|
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
|
||||||
|
|
||||||
|
JsonNode lossFunctionNode = null;
|
||||||
|
if (outputLayerNode.has("output")) {
|
||||||
|
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
|
||||||
|
} else if (outputLayerNode.has("rnnoutput")) {
|
||||||
|
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lossFunctionNode != null) {
|
||||||
|
String lossFunctionEnumStr = lossFunctionNode.asText();
|
||||||
|
LossFunctions.LossFunction lossFunction = null;
|
||||||
|
try {
|
||||||
|
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn(
|
||||||
|
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
||||||
|
e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lossFunction != null) {
|
||||||
|
switch (lossFunction) {
|
||||||
|
case MSE:
|
||||||
|
ol.setLossFn(new LossMSE());
|
||||||
|
break;
|
||||||
|
case XENT:
|
||||||
|
ol.setLossFn(new LossBinaryXENT());
|
||||||
|
break;
|
||||||
|
case NEGATIVELOGLIKELIHOOD:
|
||||||
|
ol.setLossFn(new LossNegativeLogLikelihood());
|
||||||
|
break;
|
||||||
|
case MCXENT:
|
||||||
|
ol.setLossFn(new LossMCXENT());
|
||||||
|
break;
|
||||||
|
|
||||||
|
//Remaining: TODO
|
||||||
|
case SQUARED_LOSS:
|
||||||
|
case RECONSTRUCTION_CROSSENTROPY:
|
||||||
|
default:
|
||||||
|
log.warn(
|
||||||
|
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
|
||||||
|
lossFunction);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
log.warn(
|
||||||
|
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
|
||||||
|
(confs != null ? confs.getClass() : null));
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.warn(
|
||||||
|
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
||||||
|
e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
|
||||||
|
//Try to load the old format if necessary, and create the appropriate IActivation instance
|
||||||
|
if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) {
|
||||||
|
try {
|
||||||
|
JsonNode jsonNode = mapper.readTree(json);
|
||||||
|
if (confs == null) {
|
||||||
|
confs = jsonNode.get("confs");
|
||||||
|
}
|
||||||
|
if (confs instanceof ArrayNode) {
|
||||||
|
ArrayNode layerConfs = (ArrayNode) confs;
|
||||||
|
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
||||||
|
if (outputLayerNNCNode == null) {
|
||||||
|
throw new RuntimeException("Should never happen"); //return conf; //Should never happen...
|
||||||
|
}
|
||||||
|
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
||||||
|
|
||||||
|
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
JsonNode layerNode = layerWrapperNode.elements().next();
|
||||||
|
JsonNode activationFunction = layerNode.get(
|
||||||
|
"activationFunction"); //Should only have 1 element: "dense", "output", etc
|
||||||
|
|
||||||
|
if (activationFunction != null) {
|
||||||
|
IActivation ia = Activation.fromString(activationFunction.asText())
|
||||||
|
.getActivationFunction();
|
||||||
|
((BaseLayer) l).setActivationFn(ia);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.warn(
|
||||||
|
"ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
|
||||||
|
e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
|
||||||
|
return conf;
|
||||||
|
}
|
||||||
|
|
||||||
|
layerCount++;
|
||||||
|
}
|
||||||
|
return conf;
|
||||||
|
}
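A short usage sketch for the method above (the file name and the assumption that the JSON was produced by toJson() are hypothetical):

    // Restore a configuration that was previously saved with toJson()
    String json = Files.readString(Path.of("net-config.json"));
    NeuralNetworkConfiguration restored = NeuralNetworkConfiguration.fromJson(json);
    // Legacy (pre-1.0.0-beta) JSON is retried with the legacy mapper automatically;
    // only legacy custom layers are rejected and must be re-saved with 1.0.0-beta..beta4 first.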
  /**
   * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
   * from handling of {@link Activation} above.
   *
   * @return True if all is well and layer iteration shall continue. False otherwise.
   */
  private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper,
      JsonNode confs, int layerCount) {
    if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) {
      try {
        JsonNode jsonNode = mapper.readTree(json);
        if (confs == null) {
          confs = jsonNode.get("confs");
        }
        if (confs instanceof ArrayNode) {
          ArrayNode layerConfs = (ArrayNode) confs;
          JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
          if (outputLayerNNCNode == null) {
            return false; //Should never happen...
          }
          JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");

          if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
            return true;
          }

          JsonNode layerNode = layerWrapperNode.elements().next();
          JsonNode weightInit = layerNode.get(
              "weightInit"); //Should only have 1 element: "dense", "output", etc
          JsonNode distribution = layerNode.get("dist");

          Distribution dist = null;
          if (distribution != null) {
            dist = mapper.treeToValue(distribution, Distribution.class);
          }

          if (weightInit != null) {
            final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
                .getWeightInitFunction(dist);
            ((BaseLayer) l).setWeightInitFn(wi);
          }
        }

      } catch (IOException e) {
        log.warn(
            "ILayer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON",
            e);
      }
    }
    return true;
  }
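In isolation, the legacy mapping above boils down to resolving the old enum value (plus an optional distribution node) to an IWeightInit instance; a small illustration using the standard DL4J enum:

    IWeightInit xavier = WeightInit.XAVIER.getWeightInitFunction(null);            // no distribution needed
    IWeightInit normal = WeightInit.DISTRIBUTION
        .getWeightInitFunction(new NormalDistribution(0.0, 0.01));                 // built from the "dist" node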
  /**
   * Object mapper for serialization of configurations
   *
   * @return the YAML ObjectMapper
   */
  public static ObjectMapper mapperYaml() {
    return JsonMappers.getMapperYaml();
  }

  /**
   * Object mapper for serialization of configurations
   *
   * @return the JSON ObjectMapper
   */
  public static ObjectMapper mapper() {
    return JsonMappers.getMapper();
  }

  /**
   * @return YAML representation of NN configuration
   */
  public String toYaml() {
    ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
    synchronized (mapper) {
      try {
        return mapper.writeValueAsString(this);
      } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * @return JSON representation of NN configuration
   */
  public String toJson() {
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    synchronized (mapper) {
      //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
      //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
      try {
        return mapper.writeValueAsString(this);
      } catch (com.fasterxml.jackson.core.JsonProcessingException e) {
        throw new RuntimeException(e);
      }
    }
  }
  @Override
- protected Object clone() throws CloneNotSupportedException {
-   return super.clone();
+ public String toString() {
+   return toJson();
  }

  public NeuralNetworkConfiguration getConf(int i) {
    return confs.get(i);
  }

  @Override
  public NeuralNetworkConfiguration clone() {
    NeuralNetworkConfiguration clone = (NeuralNetworkConfiguration) super.clone();
    List<NeuralNetworkConfiguration> confList = clone.getConfs();
    if (confList != null) {
      List<NeuralNetworkConfiguration> list = new ArrayList<>();
      for (NeuralNetworkConfiguration conf : confList) {
        list.add(conf.clone());
      }
    }

    if (clone.getInputPreProcessors() != null) {
      Map<Integer, InputPreProcessor> map = new HashMap<>();
      for (Map.Entry<Integer, InputPreProcessor> entry : clone.getInputPreProcessors().entrySet()) {
        map.put(entry.getKey(), entry.getValue().clone());
      }
      clone.getInputPreProcessors().clear();
      clone.getInputPreProcessors().putAll(map);
    }

    clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode);
    clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode);
    clone.setCacheMode(this.cacheMode);
    clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig);
    clone.setDataType(this.dataType);

    return clone;
  }

  public InputPreProcessor getInputPreProcess(int curr) {
    return inputPreProcessors.get(curr);
  }

  /**
   * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the
   * memory requirements for the given network configuration and input
   *
   * @param inputType Input types for the network
   * @return Memory report for the network
   */
  public NetworkMemoryReport getMemoryReport(InputType inputType) {
    Map<String, MemoryReport> memoryReportMap = new LinkedHashMap<>();
    int nLayers = confs.size();
    for (int i = 0; i < nLayers; i++) {
      String layerName = confs.get(i).getLayer().getLayerName();
      if (layerName == null) {
        layerName = String.valueOf(i);
      }

      //Pass input type through preprocessor, if necessary
      InputPreProcessor preproc = getInputPreProcess(i);
      //TODO memory requirements for preprocessor
      if (preproc != null) {
        inputType = preproc.getOutputType(inputType);
      }

      LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType);
      memoryReportMap.put(layerName, report);

      inputType = confs.get(i).getLayer().getOutputType(i, inputType);
    }

    return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class,
        "MultiLayerNetwork", inputType);
  }

  /**
   * For the given input shape/type for the network, return a list of activation sizes for each
   * layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
   *
   * @param inputType Input type for the network
   * @return A list of activation types for the network, indexed by layer number
   */
  public List<InputType> getLayerActivationTypes(@NonNull InputType inputType) {
    List<InputType> out = new ArrayList<>();
    int nLayers = confs.size();
    for (int i = 0; i < nLayers; i++) {
      InputPreProcessor preproc = getInputPreProcess(i);
      if (preproc != null) {
        inputType = preproc.getOutputType(inputType);
      }

      inputType = confs.get(i).getLayer().getOutputType(i, inputType);
      out.add(inputType);
    }
    return out;
  }
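A usage sketch for the two methods above (the 784-unit feed-forward input is a made-up example; conf stands for an already built configuration):

    List<InputType> types = conf.getLayerActivationTypes(InputType.feedForward(784));
    for (int i = 0; i < types.size(); i++) {
        System.out.println("layer " + i + " -> " + types.get(i));   // activation type/shape per layer
    }
    NetworkMemoryReport mem = conf.getMemoryReport(InputType.feedForward(784));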
  /**
   * Defines some additional handy methods. Other than that,
   * the builder is generated by lombok.
   */
  public static class NeuralNetworkConfigurationBuilder {

    /**
     * Specify the processors. These are used at each layer for doing things like normalization and
     * shaping of input.
     *
     * @param processor what to use to preProcess the data.
     * @return builder pattern
     */
    public NeuralNetworkConfigurationBuilder inputPreProcessor(Integer layer,
        InputPreProcessor processor) {
      inputPreProcessors.put(layer, processor);
      return this;
    }

    /**
     * Specify additional layer configurations
     */
    @Deprecated
    public NeuralNetworkConfigurationBuilder layersFromArray(Layer[] arrLayers) {
      for (Layer l : arrLayers) {
        layers.add(l);
      }
      return this;
    }
  }

}
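A hypothetical builder call for the preprocessor hook above (apart from inputPreProcessor(..), the lombok-generated builder methods are assumptions):

    NeuralNetworkConfiguration conf = NeuralNetworkConfiguration.builder()
        .inputPreProcessor(0, new CnnToFeedForwardPreProcessor(28, 28, 1))   // reshape image activations for layer 0
        .build();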
@ -24,12 +24,12 @@ package net.brutex.ai.dnn.conf.layer;

 import lombok.Getter;
 import lombok.NonNull;
 import lombok.Setter;
-import net.brutex.ai.dnn.api.LayerConfiguration;
-import org.deeplearning4j.nn.conf.inputs.InputType;
+import lombok.experimental.SuperBuilder;
+import net.brutex.ai.dnn.api.ILayerConfiguration;

-public abstract class AbstractLayerConfiguration implements LayerConfiguration {
+@SuperBuilder
+public abstract class AbstractLayerConfiguration implements ILayerConfiguration {

   @Getter @Setter @NonNull
-  private InputType.Type inputType;
+  private String name;

 }
@ -0,0 +1,62 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.ai.dnn.conf.layer;

import lombok.Builder;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.layers.LayerValidation;

/**
 * The dense layer is a neural network layer that is connected deeply, which means each neuron in
 * the dense layer receives input from all neurons of its previous layer. The dense layer is the
 * most commonly used layer in most models.
 * <p>
 * In the background, the dense layer performs a matrix-vector multiplication. The values used in
 * the matrix are parameters that are trained and updated with the help of backpropagation.
 * <p>
 * The output generated by the dense layer is an 'm'-dimensional vector. Thus, the dense layer is
 * basically used for changing the dimensions of the vector. Dense layers can also apply operations
 * such as rotation, scaling and translation to the vector.
 */
@SuperBuilder
public class DenseLayerConfiguration extends FeedForwardLayerConfiguration {

  /**
   * Decides whether we should include a bias vector for calculation purposes or not.
   */
  @Builder.Default
  boolean bias = true;

  /**
   * An implementation to validate the network
   *
   * @return true if no errors found; false otherwise
   */
  @Override
  public boolean isValid() {
    LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getName(), -99, getIn(), getOut());
    return super.isValid();
  }
}
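The matrix-vector product described in the class comment above, spelled out on toy numbers (standalone illustration, not code from this commit):

    // y = W * x + b with in = 2 inputs and out = 3 units, before the non-linearity
    double[][] W = {{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}};
    double[] x = {1.0, 2.0};
    double[] b = {0.0, 0.0, 0.0};
    double[] y = new double[3];
    for (int i = 0; i < 3; i++) {
        y[i] = b[i];
        for (int j = 0; j < 2; j++) {
            y[i] += W[i][j] * x[j];
        }
    }
    // y == {0.5, 1.1, 1.7}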
@ -0,0 +1,99 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.ai.dnn.conf.layer;

import lombok.Getter;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.ILayer;
import net.brutex.ai.dnn.api.ILayerConfiguration;
import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.Type;

/**
 * A Feed Forward Layer Configuration
 */
@Slf4j
@SuperBuilder
public class FeedForwardLayerConfiguration extends AbstractLayerConfiguration implements ILayerConfiguration {

  @Getter private int in;
  @Getter private int out;

  /**
   * This Feed Forward ILayer will always output data as
   * FF type.
   * @return InputType for FF
   **/
  @Getter
  final InputType.Type outputType = InputType.Type.FF;

  @Getter
  final InputType.Type inputType = InputType.Type.FF;

  /**
   * Create and return an instance of an ILayer for this configuration.
   *
   * @param network the "holding" network for the instance
   * @return the new layer instance
   */
  //@Override
  public ILayer instantiate(IModel network) {
    //Let's do some verifications first
    if (getInputType() != Type.FF) {
      log.error("The {} layer configuration must use an InputType of {}, but found {}",
          this.getClass().getSimpleName(),
          Type.FF.name(),
          getInputType().name());
    }
    return null;
  }

  /**
   * Number of trainable parameters in this layer
   *
   * @return number of parameters
   */
  @Override
  public long numParameters() {
    return in * out + out; //add one extra out for the bias
  }

  /**
   * An implementation should provide a method to validate the network
   *
   * @return true if no errors found; false otherwise
   */
  @Override
  public boolean isValid() {
    boolean result = true;
    if (getInputType() != Type.FF) {
      log.error("The {} layer configuration must use an InputType of {}, but found {}",
          this.getClass().getSimpleName(),
          Type.FF.name(),
          getInputType().name());
      result = false;
    }
    return result;
  }
}
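The parameter count above, checked on concrete sizes (the lombok @SuperBuilder accessors used here are an assumption):

    FeedForwardLayerConfiguration ff = FeedForwardLayerConfiguration.builder().in(784).out(100).build();
    // ff.numParameters() == 784 * 100 + 100 == 78_500  (weight matrix plus one bias per output unit)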
@ -1,72 +0,0 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package net.brutex.ai.dnn.impl.network;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public abstract class AbstractNeuralNetwork implements NeuralNetwork {

  @Getter @Setter @NonNull
  private String name;

  @Getter @NonNull
  private NeuralNetworkConfiguration configuration;

  @Getter
  private final Collection<TrainingListener> trainingListeners = new HashSet<>();

  /**
   * The neural network holds an instantiation of its configured
   * layers.
   * @return the actual runtime layers
   */
  @Getter
  private final List<Layer> runtimeLayers = new ArrayList<>();

  /**
   * Sets the configuration to be used. Each time a configuration is set, the runtime layers
   * of this NeuralNetwork are updated from the configuration.
   *
   * @param conf the configuration to use for this network
   */
  public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) {
    List<LayerConfiguration> layers = conf.getLayerConfigurations();
    for (LayerConfiguration layer : layers) {
      Layer initializedLayer = layer.instantiate(this);
      this.getRuntimeLayers().add(initializedLayer);
    }
    this.configuration = configuration;
  }

}
@ -1,692 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package net.brutex.ai.dnn.impl.network;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Map;
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import lombok.val;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
|
||||||
import org.deeplearning4j.nn.api.Classifier;
|
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
|
||||||
import org.deeplearning4j.nn.api.Updater;
|
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
|
||||||
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.BackpropType;
|
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
||||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
|
||||||
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
|
||||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
||||||
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
import org.deeplearning4j.optimize.Solver;
|
|
||||||
import org.deeplearning4j.optimize.api.ConvexOptimizer;
|
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
|
||||||
import org.deeplearning4j.util.CrashReportingUtil;
|
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|
||||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.heartbeat.Heartbeat;
|
|
||||||
import org.nd4j.linalg.heartbeat.reports.Environment;
|
|
||||||
import org.nd4j.linalg.heartbeat.reports.Event;
|
|
||||||
import org.nd4j.linalg.heartbeat.reports.Task;
|
|
||||||
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
|
||||||
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class NeuralNetwork extends AbstractNeuralNetwork {
|
|
||||||
|
|
||||||
|
|
||||||
//the hidden neural network layers (including output layer)
|
|
||||||
protected Layer[] layers;
|
|
||||||
|
|
||||||
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
|
|
||||||
|
|
||||||
//Current training data: input features and labels
|
|
||||||
@Getter @Setter @NonNull
|
|
||||||
protected INDArray input;
|
|
||||||
@Getter @Setter
|
|
||||||
protected INDArray labels;
|
|
||||||
|
|
||||||
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
|
|
||||||
@Getter
|
|
||||||
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Used to call optimizers during backprop
|
|
||||||
*/
|
|
||||||
@NonNull
|
|
||||||
protected transient Solver solver = new Solver.Builder().configure(getConfiguration()).
|
|
||||||
listeners(getTrainingListeners()).model(this).build();
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new NeuralNetwork from the given configuration
|
|
||||||
* @param conf
|
|
||||||
*/
|
|
||||||
public NeuralNetwork(NeuralNetworkConfiguration conf) {
|
|
||||||
if(! validateConfiguration() ) {
|
|
||||||
log.error("Configuration '{}' has failed validation.", conf.getName());
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
log.info("Configuration '{}' has been validated successfully.", conf.getName());
|
|
||||||
this.conf = conf;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean validateConfiguration() {
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void logNotImplemented( ) {
|
|
||||||
// getStackTrace() method return
|
|
||||||
// current method name at 0th index
|
|
||||||
String method = new Throwable()
|
|
||||||
.getStackTrace()[1]
|
|
||||||
.getMethodName();
|
|
||||||
log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method does initialization of model
|
|
||||||
* <p>
|
|
||||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void init() {
|
|
||||||
logNotImplemented();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns model parameters as single INDArray
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public INDArray params() {
|
|
||||||
logNotImplemented();
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns updater state (if applicable), null otherwise
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public INDArray updaterState() {
|
|
||||||
return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns Optimizer used for training
|
|
||||||
*
|
|
||||||
* @return the optimizer
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public ConvexOptimizer getOptimizer() {
|
|
||||||
return solver.getOptimizer();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/** Get the updater for this NeuralNetwork from the Solver
|
|
||||||
* @return Updater for NeuralNetwork
|
|
||||||
*/
|
|
||||||
private Updater getUpdater(boolean initializeIfReq) {
|
|
||||||
if (solver == null && initializeIfReq) {
|
|
||||||
synchronized(this){
|
|
||||||
if(solver == null) { //May have been created while waiting for lock
|
|
||||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build();
|
|
||||||
solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if(solver != null) {
|
|
||||||
return solver.getOptimizer().getUpdater(initializeIfReq);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the updater for the NeuralNetwork in the Solver
|
|
||||||
* */
|
|
||||||
public void setUpdater(@NonNull Updater updater) {
|
|
||||||
solver.getOptimizer().setUpdater(updater);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void fit(MultiDataSet dataSet) {
|
|
||||||
if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
|
|
||||||
INDArray features = dataSet.getFeatures(0);
|
|
||||||
INDArray labels = dataSet.getLabels(0);
|
|
||||||
INDArray fMask = null;
|
|
||||||
INDArray lMask = null;
|
|
||||||
|
|
||||||
if (dataSet.getFeaturesMaskArrays() != null)
|
|
||||||
fMask = dataSet.getFeaturesMaskArrays()[0];
|
|
||||||
|
|
||||||
if (dataSet.getFeaturesMaskArrays() != null)
|
|
||||||
lMask = dataSet.getLabelsMaskArrays()[0];
|
|
||||||
|
|
||||||
DataSet ds = new DataSet(features, labels, fMask, lMask);
|
|
||||||
fit(ds);
|
|
||||||
} else {
|
|
||||||
throw new DL4JInvalidInputException(
|
|
||||||
"MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." +
|
|
||||||
"Please consider use of ComputationGraph");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
|
|
||||||
* Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
|
|
||||||
*
|
|
||||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
|
||||||
* @param numEpochs Number of training epochs, >= 1
|
|
||||||
*/
|
|
||||||
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
|
|
||||||
Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs);
|
|
||||||
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" +
|
|
||||||
"iterator has does not support resetting (iterator.resetSupported() returned false)");
|
|
||||||
|
|
||||||
for(int i = 0; i < numEpochs; i++) {
|
|
||||||
fit(iterator);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator.<br>
|
|
||||||
* Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as
|
|
||||||
* MultiLayerNetwork only supports 1 input and 1 output)
|
|
||||||
*
|
|
||||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void fit(MultiDataSetIterator iterator) {
|
|
||||||
fit(new MultiDataSetWrapperIterator(iterator));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.<br>
|
|
||||||
* Note that this method does not do layerwise pretraining.<br>
|
|
||||||
* For pretraining use method pretrain.. #pretrain(DataSetIterator)<br>
|
|
||||||
* @param iterator Training data (DataSetIterator)
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void fit(DataSetIterator iterator) {
|
|
||||||
try{
|
|
||||||
fitHelper(iterator);
|
|
||||||
} catch (OutOfMemoryError e){
|
|
||||||
CrashReportingUtil.writeMemoryCrashDump(this, e);
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private synchronized void fitHelper(DataSetIterator iterator){
|
|
||||||
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
|
|
||||||
DataSetIterator iter;
|
|
||||||
boolean destructable = false;
|
|
||||||
if (iterator.asyncSupported()) {
|
|
||||||
iter = new AsyncDataSetIterator(iterator, Math.min(
|
|
||||||
Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
|
|
||||||
destructable = true;
|
|
||||||
} else {
|
|
||||||
iter = iterator;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (TrainingListener tl : trainingListeners) {
|
|
||||||
tl.onEpochStart(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
LayerWorkspaceMgr workspaceMgr;
|
|
||||||
if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
|
|
||||||
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
|
|
||||||
} else {
|
|
||||||
workspaceMgr = LayerWorkspaceMgr.builder()
|
|
||||||
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
|
||||||
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
|
||||||
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
|
||||||
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
|
||||||
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
|
||||||
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
|
||||||
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
|
|
||||||
// as these should be closed by the time updaters are executed
|
|
||||||
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
|
|
||||||
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
|
|
||||||
|
|
||||||
update(TaskUtils.buildTask(iter));
|
|
||||||
if (!iter.hasNext() && iter.resetSupported()) {
|
|
||||||
iter.reset();
|
|
||||||
}
|
|
||||||
long time1 = System.currentTimeMillis();
|
|
||||||
while (iter.hasNext()) {
|
|
||||||
|
|
||||||
DataSet next = iter.next();
|
|
||||||
long time2 = System.currentTimeMillis();
|
|
||||||
|
|
||||||
lastEtlTime.set((time2 - time1));
|
|
||||||
|
|
||||||
if (next.getFeatures() == null || next.getLabels() == null)
|
|
||||||
break;
|
|
||||||
|
|
||||||
// TODO: basically we want to wrap internals of this loop into workspace
|
|
||||||
|
|
||||||
|
|
||||||
boolean hasMaskArrays = next.hasMaskArrays();
|
|
||||||
|
|
||||||
if (conf.getBackpropType() == BackpropType.TruncatedBPTT) {
|
|
||||||
doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(),
|
|
||||||
next.getLabelsMaskArray(), workspaceMgr);
|
|
||||||
} else {
|
|
||||||
if (hasMaskArrays)
|
|
||||||
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
|
|
||||||
|
|
||||||
setInput(next.getFeatures());
|
|
||||||
setLabels(next.getLabels());
|
|
||||||
|
|
||||||
if (solver == null) {
|
|
||||||
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
|
||||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO CACHE
|
|
||||||
solver.optimize(workspaceMgr);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasMaskArrays)
|
|
||||||
clearLayerMaskArrays();
|
|
||||||
|
|
||||||
time1 = System.currentTimeMillis();
|
|
||||||
synchronizeIterEpochCounts();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!trainingListeners.isEmpty()) {
|
|
||||||
for (TrainingListener tl : trainingListeners) {
|
|
||||||
tl.onEpochEnd(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clearLayersStates();
|
|
||||||
|
|
||||||
if (destructable)
|
|
||||||
((AsyncDataSetIterator) iter).shutdown();
|
|
||||||
|
|
||||||
incrementEpochCount();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Workspace for working memory for a single layer: forward pass and backward pass
|
|
||||||
* Note that this is opened/closed once per op (activate/backpropGradient call)
|
|
||||||
*/
|
|
||||||
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
|
|
||||||
/**
|
|
||||||
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
|
|
||||||
* Not used for inference
|
|
||||||
*/
|
|
||||||
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
|
|
||||||
/**
|
|
||||||
* Next 2 workspaces: used for:
|
|
||||||
* (a) Inference: holds activations for one layer only
|
|
||||||
* (b) Backprop: holds activation gradients for one layer only
|
|
||||||
* In both cases, they are opened and closed on every second layer
|
|
||||||
*/
|
|
||||||
protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
|
|
||||||
protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Workspace for output methods that use OutputAdapter
|
|
||||||
*/
|
|
||||||
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Workspace for working memory in RNNs - opened and closed once per RNN time step
|
|
||||||
*/
|
|
||||||
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
|
|
||||||
|
|
||||||
|
|
||||||
protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
|
|
||||||
|
|
||||||
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
|
|
||||||
.initialSize(0)
|
|
||||||
.overallocationLimit(0.05)
|
|
||||||
.policyLearning(LearningPolicy.FIRST_LOOP)
|
|
||||||
.policyReset(ResetPolicy.BLOCK_LEFT)
|
|
||||||
.policySpill(SpillPolicy.REALLOCATE)
|
|
||||||
.policyAllocation(AllocationPolicy.OVERALLOCATE)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
|
|
||||||
|
|
||||||
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
|
|
||||||
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
|
|
||||||
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
|
|
||||||
.policyLearning(LearningPolicy.FIRST_LOOP).build();
|
|
||||||
|
|
||||||
|
|
||||||
boolean initDone;
|
|
||||||
protected void update(Task task) {
|
|
||||||
if (!initDone) {
|
|
||||||
initDone = true;
|
|
||||||
Heartbeat heartbeat = Heartbeat.getInstance();
|
|
||||||
task = ModelSerializer.taskByModel(this);
|
|
||||||
Environment env = EnvironmentUtils.buildEnvironment();
|
|
||||||
heartbeat.reportEvent(Event.STANDALONE, env, task);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
  protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray,
      INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
    if (input.rank() != 3 || labels.rank() != 3) {
      log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got "
          + Arrays.toString(input.shape()) + "\tand labels with shape "
          + Arrays.toString(labels.shape()));
      return;
    }
    if (input.size(2) != labels.size(2)) {
      log.warn("Input and label time series have different lengths: {} input length, {} label length",
          input.size(2), labels.size(2));
      return;
    }

    int fwdLen = conf.getTbpttFwdLength();
    update(TaskUtils.buildTask(input, labels));
    val timeSeriesLength = input.size(2);
    long nSubsets = timeSeriesLength / fwdLen;
    if (timeSeriesLength % fwdLen != 0)
      nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)

    rnnClearPreviousState();

    for (int i = 0; i < nSubsets; i++) {
      long startTimeIdx = (long) i * fwdLen;
      long endTimeIdx = startTimeIdx + fwdLen;
      if (endTimeIdx > timeSeriesLength)
        endTimeIdx = timeSeriesLength;

      if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();
      INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
          featuresMaskArray, labelsMaskArray);

      setInput(subsets[0]);
      setLabels(subsets[1]);
      setLayerMaskArrays(subsets[2], subsets[3]);

      if (solver == null) {
        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
          solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
              .build();
        }
      }
      solver.optimize(workspaceMgr);

      //Finally, update the state of the RNN layers:
      updateRnnStateWithTBPTTState();
    }

    rnnClearPreviousState();
    clearLayerMaskArrays();
  }

  private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels,
      INDArray fMask, INDArray lMask) {
    INDArray[] out = new INDArray[4];
    out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
        NDArrayIndex.interval(startTimeIdx, endTimeIdx));
    out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
        NDArrayIndex.interval(startTimeIdx, endTimeIdx));

    if (fMask != null) {
      out[2] = fMask.get(NDArrayIndex.all(),
          NDArrayIndex.interval(startTimeIdx, endTimeIdx));
    }
    if (lMask != null) {
      out[3] = lMask.get(NDArrayIndex.all(),
          NDArrayIndex.interval(startTimeIdx, endTimeIdx));
    }

    return out;
  }
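The subset arithmetic from the comment in doTruncatedBPTT above, as a standalone check:

    long timeSeriesLength = 120, fwdLen = 100;
    long nSubsets = timeSeriesLength / fwdLen;          // 1
    if (timeSeriesLength % fwdLen != 0) nSubsets++;     // 2 -> segments [0,100) and [100,120)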
||||||
|
|
||||||
/**
|
|
||||||
* Intended for internal/developer use
|
|
||||||
*/
|
|
||||||
public void updateRnnStateWithTBPTTState() {
|
|
||||||
Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{});
|
|
||||||
for (int i = 0; i < layers.length; i++) {
|
|
||||||
if (layers[i] instanceof RecurrentLayer) {
|
|
||||||
RecurrentLayer l = ((RecurrentLayer) layers[i]);
|
|
||||||
l.rnnSetPreviousState(l.rnnGetTBPTTState());
|
|
||||||
} else if (layers[i] instanceof MultiLayerNetwork) {
|
|
||||||
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Clear the previous state of the RNN layers (if any).
|
|
||||||
*/
|
|
||||||
public void rnnClearPreviousState() {
|
|
||||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
|
||||||
if (layers == null)
|
|
||||||
return;
|
|
||||||
for (int i = 0; i < layers.length; i++) {
|
|
||||||
if (layers[i] instanceof RecurrentLayer)
|
|
||||||
((RecurrentLayer) layers[i]).rnnClearPreviousState();
|
|
||||||
else if (layers[i] instanceof MultiLayerNetwork) {
|
|
||||||
((MultiLayerNetwork) layers[i]).rnnClearPreviousState();
|
|
||||||
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
|
|
||||||
((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/** Remove the mask arrays from all layers.<br>
|
|
||||||
* See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays.
|
|
||||||
*/
|
|
||||||
public void clearLayerMaskArrays() {
|
|
||||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
|
||||||
for (Layer layer : layers) {
|
|
||||||
layer.setMaskArray(null);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1).
|
|
||||||
* Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
|
|
||||||
* {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc),
|
|
||||||
* the network has no way to know when one epoch ends and another starts. In such situations, this method
|
|
||||||
* can be used to increment the epoch counter.<br>
|
|
||||||
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
|
|
||||||
*
|
|
||||||
* The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()}
|
|
||||||
*/
|
|
||||||
public void incrementEpochCount(){
|
|
||||||
conf.setEpochCount(conf.getEpochCount() + 1);
|
|
||||||
synchronizeIterEpochCounts();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void synchronizeIterEpochCounts() {
|
|
||||||
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
|
|
||||||
int currIter = conf.getIterationCount();
|
|
||||||
int currEpoch = conf.getEpochCount();
|
|
||||||
log.error("Something went wrong here. Code incomplete");
|
|
||||||
/*for(Layer l : conf.getLayers()) {
|
|
||||||
l.setIterationCount(currIter);
|
|
||||||
l.setEpochCount(currEpoch);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method just makes sure there's no state preserved within layers
|
|
||||||
*/
|
|
||||||
public void clearLayersStates() {
|
|
||||||
for (Layer layer : layers) {
|
|
||||||
layer.clear();
|
|
||||||
layer.clearNoiseWeightParams();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
|
|
||||||
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
|
|
||||||
* within the same minibatch.<br>
|
|
||||||
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
|
|
||||||
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
|
|
||||||
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
|
|
||||||
* at a given time step).<br>
|
|
||||||
* <b>NOTE</b>: This method is not usually used directly. Instead, methods such as @link #feedForward(INDArray, INDArray, INDArray)}
|
|
||||||
* and @link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally.
|
|
||||||
* @param featuresMaskArray Mask array for features (input)
|
|
||||||
* @param labelsMaskArray Mask array for labels (output)
|
|
||||||
* @see #clearLayerMaskArrays()
|
|
||||||
*/
|
|
||||||
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
|
|
||||||
if (featuresMaskArray != null) {
|
|
||||||
|
|
||||||
if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
|
|
||||||
throw new ND4JArraySizeException();
|
|
||||||
//New approach: use feedForwardMaskArray method
|
|
||||||
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
//feedforward layers below a RNN layer: need the input (features) mask array
|
|
||||||
//Reason: even if the time series input is zero padded, the output from the dense layers are
|
|
||||||
// non-zero (i.e., activationFunction(0*weights + bias) != 0 in general)
|
|
||||||
//This assumes that the time series input is masked - i.e., values are 0 at the padded time steps,
|
|
||||||
// so we don't need to do anything for the recurrent layer
|
|
||||||
|
|
||||||
//Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order
|
|
||||||
// as is done for 3d -> 2d time series reshaping
|
|
||||||
INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray);
|
|
||||||
|
|
||||||
for( int i=0; i<layers.length-1; i++ ){
|
|
||||||
Type t = layers[i].type();
|
|
||||||
if( t == Type.CONVOLUTIONAL || t == Type.FEED_FORWARD ){
|
|
||||||
layers[i].setMaskArray(reshapedFeaturesMask);
|
|
||||||
} else if( t == Type.RECURRENT ) break;
|
|
||||||
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
if (labelsMaskArray != null) {
|
|
||||||
if (!(getOutputLayer() instanceof IOutputLayer))
|
|
||||||
return;
|
|
||||||
layers[layers.length - 1].setMaskArray(labelsMaskArray);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the output layer - i.e., the last layer in the netwok
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Layer getOutputLayer() {
|
|
||||||
Layer ret = layers[layers.length - 1];
|
|
||||||
if (ret instanceof FrozenLayerWithBackprop) {
|
|
||||||
ret = ((FrozenLayerWithBackprop) ret).getInsideLayer();
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
|
|
||||||
int minibatchSize) {
|
|
||||||
if (maskArray == null) {
|
|
||||||
for (int i = 0; i < layers.length; i++) {
|
|
||||||
layers[i].feedForwardMaskArray(null, null, minibatchSize);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
//Do a forward pass through each preprocessor and layer
|
|
||||||
for (int i = 0; i < layers.length; i++) {
|
|
||||||
InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i);
|
|
||||||
|
|
||||||
if (preProcessor != null) {
|
|
||||||
Pair<INDArray, MaskState> p =
|
|
||||||
preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
|
||||||
if (p != null) {
|
|
||||||
maskArray = p.getFirst();
|
|
||||||
currentMaskState = p.getSecond();
|
|
||||||
} else {
|
|
||||||
maskArray = null;
|
|
||||||
currentMaskState = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Pair<INDArray, MaskState> p =
|
|
||||||
layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
|
||||||
if (p != null) {
|
|
||||||
maskArray = p.getFirst();
|
|
||||||
currentMaskState = p.getSecond();
|
|
||||||
} else {
|
|
||||||
maskArray = null;
|
|
||||||
currentMaskState = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Pair<>(maskArray, currentMaskState);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@@ -0,0 +1,53 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package net.brutex.ai.dnn.networks;
+
+import lombok.Getter;
+import lombok.Setter;
+import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
+import net.brutex.ai.dnn.api.INeuralNetwork;
+
+/**
+ * Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
+ * transforms this input data by calculating a weighted sum over the inputs and (3) applies a
+ * non-linear function to this transformation to calculate an intermediate state. The three steps
+ * above constitute what is known as a layer, and the transformative function is often referred to
+ * as a unit. The intermediate states—often termed features—are used as the input into another
+ * layer.
+ * <p>
+ * Through repetition of these steps, the artificial neural network learns multiple layers of
+ * non-linear features, which it then combines in a final layer to create a prediction.
+ * <p>
+ * The neural network learns by generating an error signal that measures the difference between the
+ * predictions of the network and the desired values and then using this error signal to change the
+ * weights (or parameters) so that predictions get more accurate.
+ */
+public abstract class ArtificialNeuralNetwork implements INeuralNetwork {
+
+  /**
+   * A neural network is created from a configuration.
+   * @param conf The (new net.brutex.ai) configuration for the network
+   */
+  @Getter
+  @Setter //TODO make this also final and @NonNull
+  private NeuralNetworkConfiguration configuration;
+}
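To make the three steps in the new Javadoc concrete, here is a minimal, framework-free sketch of one layer's forward pass and a gradient-descent weight update; the array shapes and names are illustrative assumptions only, not the implementation this file introduces:

public class LayerSketch {
    // (1) take input, (2) weighted sum over inputs, (3) non-linear function -> features
    static double[] forward(double[] input, double[][] weights, double[] bias) {
        double[] out = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            double z = bias[j];
            for (int i = 0; i < input.length; i++) {
                z += weights[j][i] * input[i];   // weighted sum over the inputs
            }
            out[j] = Math.tanh(z);               // non-linearity produces the intermediate state
        }
        return out;                              // these features feed the next layer
    }

    // The error signal drives a small step on each weight (gradient descent).
    static void update(double[][] weights, double[][] grad, double learningRate) {
        for (int j = 0; j < weights.length; j++)
            for (int i = 0; i < weights[j].length; i++)
                weights[j][i] -= learningRate * grad[j][i];
    }
}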
@@ -346,7 +346,7 @@ public abstract class BaseEarlyStoppingTrainer<T extends Model> implements IEarl
         } else if(model instanceof ComputationGraph){
             ComputationGraph cg = ((ComputationGraph) model);
             listeners = cg.getListeners();
-            cg.getConfiguration().setEpochCount(epochNum);
+            cg.getComputationGraphConfiguration().setEpochCount(epochNum);
         } else {
             return;
         }
@@ -431,7 +431,7 @@ public class GradientCheckUtil {
                             + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
         }
 
-        DataType netDataType = c.net.getConfiguration().getDataType();
+        DataType netDataType = c.net.getComputationGraphConfiguration().getDataType();
         if (netDataType != DataType.DOUBLE) {
             throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision ("
                             + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil");
@@ -444,8 +444,8 @@ public class GradientCheckUtil {
 
         //Check configuration
         int layerCount = 0;
-        for (String vertexName : c.net.getConfiguration().getVertices().keySet()) {
-            GraphVertex gv = c.net.getConfiguration().getVertices().get(vertexName);
+        for (String vertexName : c.net.getComputationGraphConfiguration().getVertices().keySet()) {
+            GraphVertex gv = c.net.getComputationGraphConfiguration().getVertices().get(vertexName);
             if (!(gv instanceof LayerVertex))
                 continue;
             LayerVertex lv = (LayerVertex) gv;
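As background (not part of this changeset), the double-precision requirement exists because the utility compares analytic gradients against central finite differences, which are numerically meaningless in float32. A minimal sketch of that comparison for a single parameter, with the score function assumed to be supplied by the caller:

import java.util.function.DoubleUnaryOperator;

public class FiniteDifferenceSketch {
    /** scoreAt evaluates the network score with one parameter replaced by the given value. */
    static double relativeError(DoubleUnaryOperator scoreAt, double param, double analyticGrad) {
        double eps = 1e-6; // step size; only works reliably in double precision
        double numericGrad =
            (scoreAt.applyAsDouble(param + eps) - scoreAt.applyAsDouble(param - eps)) / (2 * eps);
        return Math.abs(analyticGrad - numericGrad)
            / Math.max(Math.abs(analyticGrad) + Math.abs(numericGrad), 1e-12);
    }
}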
@@ -32,17 +32,18 @@ import org.nd4j.common.primitives.Pair;
 import java.io.Serializable;
 import java.util.Collection;
 
+/**
+ * A layer is the highest-level building block in deep learning. A layer is a container that usually
+ * receives weighted input, transforms it with a set of mostly non-linear functions and then passes
+ * these values as output to the next layer. A layer is usually uniform, that is it only contains
+ * one type of activation function, pooling, convolution etc. so that it can be easily compared to
+ * other parts of the network. The first and last layers in a network are called input and output
+ * layers, respectively, and all layers in between are called hidden layers.
+ *
+ * @see <a href="https://developer.nvidia.com/blog/deep-learning-nutshell-core-concept">NVIDIA Deep Learning In A Nutshell</a>
+ */
 public interface Layer extends Serializable, Cloneable, Model, Trainable {
 
-    enum Type {
-        FEED_FORWARD, RECURRENT, CONVOLUTIONAL, CONVOLUTIONAL3D,
-        SUBSAMPLING, UPSAMPLING, RECURSIVE, MULTILAYER, NORMALIZATION
-    }
-
-    enum TrainingMode {
-        TRAIN, TEST
-    }
-
     /**
      * This method sets given CacheMode for current layer
      *
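Editor's note (illustrative, not part of this changeset): the input/hidden/output structure described in the new Javadoc corresponds to the usual builder-style configuration. A minimal sketch, assuming the stock DL4J builder API (class names may differ slightly in this fork):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LayerStackSketch {
    public static MultiLayerNetwork build() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            // hidden layers: each receives weighted input and applies a non-linearity
            .layer(new DenseLayer.Builder().nIn(784).nOut(128).activation(Activation.RELU).build())
            .layer(new DenseLayer.Builder().nIn(128).nOut(64).activation(Activation.RELU).build())
            // output layer combines the learned features into a prediction
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(64).nOut(10).activation(Activation.SOFTMAX).build())
            .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }
}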
@@ -51,11 +52,12 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
     void setCacheMode(CacheMode mode);
 
     /**
-     * Calculate the regularization component of the score, for the parameters in this layer<br>
-     * For example, the L1, L2 and/or weight decay components of the loss function<br>
+     * Calculate the regularization component of the score, for the parameters in this layer<br> For
+     * example, the L1, L2 and/or weight decay components of the loss function<br>
      *
-     * @param backpropOnlyParams If true: calculate regularization score based on backprop params only. If false: calculate
-     *                           based on all params (including pretrain params, if any)
+     * @param backpropOnlyParams If true: calculate regularization score based on backprop params
+     *                           only. If false: calculate based on all params (including pretrain
+     *                           params, if any)
      * @return the regularization score of
      */
     double calcRegularizationScore(boolean backpropOnlyParams);
@@ -67,28 +69,29 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
      */
     Type type();
 
 
     /**
      * Calculate the gradient relative to the error in the next layer
      *
-     * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where C
-     *                is cost function a=sigma(z) is activation.
+     * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where
+     *                C is cost function a=sigma(z) is activation.
      * @param workspaceMgr Workspace manager
-     * @return Pair<Gradient , INDArray> where Gradient is gradient for this layer, INDArray is epsilon (activation gradient)
-     *         needed by next layer, but before element-wise multiply by sigmaPrime(z). So for standard feed-forward layer, if this layer is
-     *         L, then return.getSecond() == dL/dIn = (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the
-     *         {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATION_GRAD} workspace via the workspace manager
+     * @return Pair<Gradient, INDArray> where Gradient is gradient for this layer, INDArray is
+     *         epsilon (activation gradient) needed by next layer, but before element-wise multiply by
+     *         sigmaPrime(z). So for standard feed-forward layer, if this layer is L, then return.getSecond()
+     *         == dL/dIn = (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the
+     *         {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATION_GRAD} workspace via the workspace
+     *         manager
      */
     Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr);
 
 
     /**
      * Perform forward pass and return the activations array with the last set input
      *
      * @param training training or test mode
      * @param workspaceMgr Workspace manager
-     * @return the activation (layer output) of the last specified input. Note that the returned array should be placed
-     *         in the {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace manager
+     * @return the activation (layer output) of the last specified input. Note that the returned array
+     *         should be placed in the {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace
+     *         via the workspace manager
      */
     INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr);
 
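For readers following the epsilon notation above, here is a minimal sketch of the chain rule a standard feed-forward (tanh) layer applies in backpropGradient. The shapes and variable names are assumptions for illustration, not the actual implementation:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

public class BackpropSketch {
    /** input [m,nIn], w [nIn,nOut], preActivation z [m,nOut], epsilon = dC/da [m,nOut]. */
    static INDArray backprop(INDArray input, INDArray w, INDArray z, INDArray epsilon) {
        INDArray a = Transforms.tanh(z);
        INDArray delta = epsilon.mul(a.mul(a).rsub(1.0));      // dC/dz = (dC/da) * sigmaPrime(z)
        INDArray weightGrad = input.transpose().mmul(delta);   // dC/dW, goes into this layer's Gradient
        INDArray biasGrad = delta.sum(0);                      // dC/db
        return delta.mmul(w.transpose());                      // dL/dIn = epsilon handed to the layer below
    }
}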
@@ -99,7 +102,8 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
      * @param training train or test mode
      * @param mgr Workspace manager.
      * @return Activations array. Note that the returned array should be placed in the
-     *         {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace manager
+     *         {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace
+     *         manager
      */
     INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr);
 
@@ -109,42 +113,42 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
     Collection<TrainingListener> getListeners();
 
     /**
-     * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, they will be
-     * replaced by this method
+     * Set the {@link TrainingListener}s for this model. If any listeners have previously been set,
+     * they will be replaced by this method
      */
     void setListeners(TrainingListener... listeners);
 
     /**
-     * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, they will be
-     * replaced by this method
+     * Set the {@link TrainingListener}s for this model. If any listeners have previously been set,
+     * they will be replaced by this method
      */
     void setListeners(Collection<TrainingListener> listeners);
 
-    /**
-     * Set the layer index.
-     */
-    void setIndex(int index);
-
     /**
      * Get the layer index.
      */
     int getIndex();
 
+    /**
+     * Set the layer index.
+     */
+    void setIndex(int index);
+
     /**
      * @return The current iteration count (number of parameter updates) for the layer/network
      */
     int getIterationCount();
 
-    /**
-     * @return The current epoch count (number of training epochs passed) for the layer/network
-     */
-    int getEpochCount();
-
     /**
      * Set the current iteration count (number of parameter updates) for the layer/network
      */
     void setIterationCount(int iterationCount);
 
+    /**
+     * @return The current epoch count (number of training epochs passed) for the layer/network
+     */
+    int getEpochCount();
+
     /**
      * Set the current epoch count (number of epochs passed ) for the layer/network
      */
@@ -155,14 +159,6 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
      */
     void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr);
 
-    /**
-     * Set current/last input mini-batch size.<br>
-     * Used for score and gradient calculations. Mini batch size may be different from
-     * getInput().size(0) due to reshaping operations - for example, when using RNNs with
-     * DenseLayer and OutputLayer. Called automatically during forward pass.
-     */
-    void setInputMiniBatchSize(int size);
-
     /**
      * Get current/last input mini-batch size, as set by setInputMiniBatchSize(int)
      *
@@ -171,16 +167,23 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
     int getInputMiniBatchSize();
 
     /**
-     * Set the mask array. Note: In general, {@link #feedForwardMaskArray(INDArray, MaskState, int)} should be used in
-     * preference to this.
+     * Set current/last input mini-batch size.<br> Used for score and gradient calculations. Mini
+     * batch size may be different from getInput().size(0) due to reshaping operations - for example,
+     * when using RNNs with DenseLayerConfiguration and OutputLayer. Called automatically during
+     * forward pass.
+     */
+    void setInputMiniBatchSize(int size);
+
+    INDArray getMaskArray();
+
+    /**
+     * Set the mask array. Note: In general, {@link #feedForwardMaskArray(INDArray, MaskState, int)}
+     * should be used in preference to this.
      *
      * @param maskArray Mask array to set
      */
     void setMaskArray(INDArray maskArray);
 
-
-    INDArray getMaskArray();
-
 
     /**
      * Returns true if the layer can be trained in an unsupervised/pretrain manner (AE, VAE, etc)
      *
@@ -188,38 +191,50 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable {
      */
     boolean isPretrainLayer();
 
 
     void clearNoiseWeightParams();
 
     /**
-     * A performance optimization: mark whether the layer is allowed to modify its input array in-place. In many cases,
-     * this is totally safe - in others, the input array will be shared by multiple layers, and hence it's not safe to
-     * modify the input array.
-     * This is usually used by ops such as dropout.
-     * @param allow If true: the input array is safe to modify. If false: the input array should be copied before it
-     *              is modified (i.e., in-place modifications are un-safe)
+     * A performance optimization: mark whether the layer is allowed to modify its input array
+     * in-place. In many cases, this is totally safe - in others, the input array will be shared by
+     * multiple layers, and hence it's not safe to modify the input array. This is usually used by ops
+     * such as dropout.
+     *
+     * @param allow If true: the input array is safe to modify. If false: the input array should be
+     *              copied before it is modified (i.e., in-place modifications are un-safe)
      */
     void allowInputModification(boolean allow);
 
 
     /**
-     * Feed forward the input mask array, setting in the layer as appropriate. This allows different layers to
-     * handle masks differently - for example, bidirectional RNNs and normal RNNs operate differently with masks (the
-     * former sets activations to 0 outside of the data present region (and keeps the mask active for future layers like
-     * dense layers), whereas normal RNNs don't zero out the activations/errors )instead relying on backpropagated error
-     * arrays to handle the variable length case.<br>
-     * This is also used for example for networks that contain global pooling layers, arbitrary preprocessors, etc.
+     * Feed forward the input mask array, setting in the layer as appropriate. This allows different
+     * layers to handle masks differently - for example, bidirectional RNNs and normal RNNs operate
+     * differently with masks (the former sets activations to 0 outside of the data present region
+     * (and keeps the mask active for future layers like dense layers), whereas normal RNNs don't zero
+     * out the activations/errors )instead relying on backpropagated error arrays to handle the
+     * variable length case.<br> This is also used for example for networks that contain global
+     * pooling layers, arbitrary preprocessors, etc.
      *
      * @param maskArray Mask array to set
      * @param currentMaskState Current state of the mask - see {@link MaskState}
-     * @param minibatchSize Current minibatch size. Needs to be known as it cannot always be inferred from the activations
-     *                      array due to reshaping (such as a DenseLayer within a recurrent neural network)
+     * @param minibatchSize Current minibatch size. Needs to be known as it cannot always be
+     *                      inferred from the activations array due to reshaping (such as a
+     *                      DenseLayerConfiguration within a recurrent neural network)
      * @return New mask array after this layer, along with the new mask state.
      */
-    Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize);
+    Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
+                    int minibatchSize);
 
     /**
      * @return Get the layer helper, if any
      */
     LayerHelper getHelper();
 
 
+    enum Type {
+        FEED_FORWARD, RECURRENT, CONVOLUTIONAL, CONVOLUTIONAL3D,
+        SUBSAMPLING, UPSAMPLING, RECURSIVE, MULTILAYER, NORMALIZATION
+    }
+
+    enum TrainingMode {
+        TRAIN, TEST
+    }
+
 }
@@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 
 public interface ModelAdapter<T> extends OutputAdapter<T> {
     /**
-     * This method invokes model internally, and does convertion to T
+     * This method invokes model internally, and does conversion to T
      * @return
      */
     T apply(Model model, INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelsMasks);
@@ -41,7 +41,7 @@ public interface ParamInitializer {
     /**
      * Get a list of all parameter keys given the layer configuration
      *
-     * @param layer Layer
+     * @param layer ILayer
      * @return All parameter keys
      */
     List<String> paramKeys(org.deeplearning4j.nn.conf.layers.Layer layer);
@@ -49,7 +49,7 @@ public interface ParamInitializer {
     /**
      * Weight parameter keys given the layer configuration
      *
-     * @param layer Layer
+     * @param layer ILayer
      * @return Weight parameter keys
      */
     List<String> weightKeys(org.deeplearning4j.nn.conf.layers.Layer layer);
@@ -57,7 +57,7 @@ public interface ParamInitializer {
     /**
      * Bias parameter keys given the layer configuration
      *
-     * @param layer Layer
+     * @param layer ILayer
      * @return Bias parameter keys
      */
     List<String> biasKeys(org.deeplearning4j.nn.conf.layers.Layer layer);
@@ -65,7 +65,7 @@ public interface ParamInitializer {
     /**
      * Is the specified parameter a weight?
      *
-     * @param layer Layer
+     * @param layer ILayer
      * @param key Key to check
      * @return True if parameter is a weight
      */
@@ -74,7 +74,7 @@ public interface ParamInitializer {
     /**
      * Is the specified parameter a bias?
      *
-     * @param layer Layer
+     * @param layer ILayer
      * @param key Key to check
      * @return True if parameter is a bias
      */
@@ -47,7 +47,7 @@ public interface TrainingConfig {
      * Is the specified parameter a layerwise pretraining only parameter?<br>
      * For example, visible bias params in an autoencoder (or, decoder params in a variational autoencoder) aren't
      * used during supervised backprop.<br>
-     * Layers (like DenseLayer, etc) with no pretrainable parameters will return false for all (valid) inputs.
+     * Layers (like DenseLayerConfiguration, etc) with no pretrainable parameters will return false for all (valid) inputs.
      *
      * @param paramName Parameter name/key
      * @return True if the parameter is for layerwise pretraining only, false otherwise
@@ -36,7 +36,7 @@ public interface Updater extends Serializable {
     /**
      * Set the internal (historical) state view array for this updater
      *
-     * @param layer Layer that this updater belongs to
+     * @param layer ILayer that this updater belongs to
      * @param viewArray View array
      * @param initialize Whether to initialize the array or not
      */
@@ -33,7 +33,7 @@ public interface LayerConstraint extends Cloneable, Serializable {
     * Apply a given constraint to a layer at each iteration
     * in the provided epoch, after parameters have been updated.
     *
-    * @param layer org.deeplearning4j.nn.api.Layer
+    * @param layer org.deeplearning4j.nn.api.ILayer
     * @param iteration given iteration as integer
    * @param epoch current epoch as integer
    */
@@ -66,10 +66,10 @@ public interface RecurrentLayer extends Layer {
     * (a) result in the same output<br>
     * (b) leave the state maps (both stateMap and tBpttStateMap) in an identical state
     *
-    * @param input Layer input
+    * @param input ILayer input
     * @param training if true: training. Otherwise: test
     * @param storeLastForTBPTT If true: store the final state in tBpttStateMap for use in truncated BPTT training
-    * @return Layer activations
+    * @return ILayer activations
     */
    INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMg);
 
@@ -92,7 +92,7 @@ public interface RecurrentLayer extends Layer {
    void rnnSetTBPTTState(Map<String, INDArray> state);
 
    /**
-    * Truncated BPTT equivalent of Layer.backpropGradient().
+    * Truncated BPTT equivalent of ILayer.backpropGradient().
     * Primary difference here is that forward pass in the context of BPTT is that we do
     * forward pass using stored state for truncated BPTT vs. from zero initialization
     * for standard BPTT.
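Usage note (illustrative, assuming the public MultiLayerNetwork API rather than this interface directly): the stored-state methods above are what make stateful, step-by-step inference possible. A minimal sketch:

import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RnnStateSketch {
    static void streamInference(MultiLayerNetwork net, List<INDArray> steps) {
        net.rnnClearPreviousState();              // start from zero initialization
        for (INDArray step : steps) {             // each step: [miniBatch, nIn, 1]
            INDArray out = net.rnnTimeStep(step); // forward pass using the stored state maps
        }
    }
}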
@@ -25,6 +25,7 @@ import lombok.EqualsAndHashCode;
 import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
+import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
@@ -68,7 +69,9 @@ import java.util.*;
 @NoArgsConstructor
 @Slf4j
 @EqualsAndHashCode(exclude = {"iterationCount", "epochCount"})
-public class NeuralNetConfiguration implements Serializable, Cloneable {
+public class NeuralNetConfiguration implements Serializable, Cloneable,
+    INeuralNetworkConfiguration {
+
 
     protected Layer layer;
     //batch size: primarily used for conv nets. Will be reinforced if set.
@@ -43,7 +43,7 @@ public class MaxNormConstraint extends BaseConstraint {
     /**
      * @param maxNorm Maximum L2 value
      * @param paramNames Which parameter names to apply constraint to
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -56,7 +56,7 @@ public class MaxNormConstraint extends BaseConstraint {
      * Apply to weights but not biases by default
      *
      * @param maxNorm Maximum L2 value
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
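For orientation (not part of this changeset): a max-norm constraint simply renormalises each weight vector along the configured dimensions after a parameter update. A framework-free sketch on a plain [nOut][nIn] matrix, so the dimension handling stays visible:

public class MaxNormSketch {
    static void applyMaxNorm(double[][] w, double maxNorm) {
        for (double[] row : w) {                       // one row = incoming weights of one unit
            double norm = 0;
            for (double v : row) norm += v * v;
            norm = Math.sqrt(norm);
            if (norm > maxNorm) {
                double scale = maxNorm / norm;         // shrink the row back onto the L2 ball
                for (int i = 0; i < row.length; i++) row[i] *= scale;
            }
        }
    }
}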
@@ -51,7 +51,7 @@ public class MinMaxNormConstraint extends BaseConstraint {
      *
      * @param max Maximum L2 value
      * @param min Minimum L2 value
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -65,7 +65,7 @@ public class MinMaxNormConstraint extends BaseConstraint {
      * @param max Maximum L2 value
      * @param min Minimum L2 value
      * @param rate Constraint rate
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -79,7 +79,7 @@ public class MinMaxNormConstraint extends BaseConstraint {
      * @param min Minimum L2 value
      * @param rate Constraint rate
      * @param paramNames Which parameter names to apply constraint to
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -39,7 +39,7 @@ public class UnitNormConstraint extends BaseConstraint {
     /**
      * Apply to weights but not biases by default
      *
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -49,7 +49,7 @@ public class UnitNormConstraint extends BaseConstraint {
 
 
     /**
-     * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should
+     * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should
      *                   be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of
      *                   parameters which have order [depthOut, depthIn, kH, kW]
      */
@@ -21,7 +21,6 @@
 package org.deeplearning4j.nn.conf.graph;
 
 import lombok.Data;
-import lombok.EqualsAndHashCode;
 import lombok.NoArgsConstructor;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -40,8 +39,8 @@ public class LayerVertex extends GraphVertex {
 
     private NeuralNetConfiguration layerConf;
     private InputPreProcessor preProcessor;
-    //Set outputVertex to true when Layer is an OutputLayer, OR For use in specialized situations like reinforcement learning
-    // For RL situations, this Layer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon
+    //Set outputVertex to true when ILayer is an OutputLayer, OR For use in specialized situations like reinforcement learning
+    // For RL situations, this ILayer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon
     // passed in externally
     private boolean outputVertex;
 
@@ -99,7 +98,7 @@ public class LayerVertex extends GraphVertex {
     public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                     INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
         //Now, we need to work out if this vertex is an output vertex or not...
-        boolean isOutput = graph.getConfiguration().getNetworkOutputs().contains(name);
+        boolean isOutput = graph.getComputationGraphConfiguration().getNetworkOutputs().contains(name);
 
         org.deeplearning4j.nn.api.Layer layer =
                         layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams, networkDatatype);
@@ -134,7 +134,7 @@ public class ActivationLayer extends NoParamLayer {
     private IActivation activationFn = null;
 
     /**
-     * Layer activation function. Typical values include:<br> "relu" (rectified linear), "tanh", "sigmoid",
+     * ILayer activation function. Typical values include:<br> "relu" (rectified linear), "tanh", "sigmoid",
      * "softmax", "hardtanh", "leakyrelu", "maxout", "softsign", "softplus"
      *
     * @deprecated Use {@link #activation(Activation)} or {@link @activation(IActivation)}
@@ -176,7 +176,7 @@ public abstract class BaseLayer extends Layer implements Serializable, Cloneable
     protected double biasInit = Double.NaN;
 
     /**
-     * Gain initialization value, for layers with Layer Normalization. Defaults to 1
+     * Gain initialization value, for layers with ILayer Normalization. Defaults to 1
      *
      */
     protected double gainInit = Double.NaN;
@@ -292,7 +292,7 @@ public abstract class BaseLayer extends Layer implements Serializable, Cloneable
         }
 
         /**
-         * Gain initialization value, for layers with Layer Normalization. Defaults to 1
+         * Gain initialization value, for layers with ILayer Normalization. Defaults to 1
          *
         * @param gainInit Value to use for initializing gain
         */
@@ -63,14 +63,14 @@ public class CapsuleLayer extends SameDiffLayer {
         this.routings = builder.routings;
 
         if(capsules <= 0 || capsuleDimensions <= 0 || routings <= 0){
-            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \""
+            throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \""
                     + layerName + "\"):"
                     + " capsules, capsuleDimensions, and routings must be > 0. Got: "
                     + capsules + ", " + capsuleDimensions + ", " + routings);
         }
 
         if(inputCapsules < 0 || inputCapsuleDimensions < 0){
-            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \""
+            throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \""
                     + layerName + "\"):"
                     + " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: "
                     + inputCapsules + ", " + inputCapsuleDimensions);
@@ -55,7 +55,7 @@ public class DenseLayer extends FeedForwardLayer {
     @Override
     public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
                     int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
-        LayerValidation.assertNInNOutSet("DenseLayer", getLayerName(), layerIndex, getNIn(), getNOut());
+        LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut());
 
         org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret =
                         new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(conf, networkDataType);
@@ -101,7 +101,7 @@ public class DenseLayer extends FeedForwardLayer {
         return new LayerMemoryReport.Builder(layerName, DenseLayer.class, inputType, outputType)
                         .standardMemory(numParams, updaterStateSize)
                         .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference
-                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer
+                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration
                         .build();
     }
 
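The assertNInNOutSet call above guards the usual builder contract: both nIn and nOut must be resolvable before instantiate(...) runs. A hedged configuration sketch, with arbitrary example values and the stock builder names:

import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;

public class DenseLayerConfigSketch {
    static DenseLayer hidden() {
        return new DenseLayer.Builder()
            .nIn(784)                    // must match the previous layer's nOut (or be set via an InputType)
            .nOut(256)
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .build();
    }
}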
@@ -205,7 +205,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
     /**
      * Is the specified parameter a layerwise pretraining only parameter?<br> For example, visible
      * bias params in an autoencoder (or, decoder params in a variational autoencoder) aren't used
-     * during supervised backprop.<br> Layers (like DenseLayer, etc) with no pretrainable parameters
+     * during supervised backprop.<br> Layers (like DenseLayerConfiguration, etc) with no pretrainable parameters
      * will return false for all (valid) inputs.
      *
      * @param paramName Parameter name/key
@@ -255,7 +255,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
     protected IDropout iDropout;
 
     /**
-     * Layer name assigns layer string name. Allows easier differentiation between layers.
+     * ILayer name assigns layer string name. Allows easier differentiation between layers.
      */
     public T name(String layerName) {
         this.setLayerName(layerName);
@@ -42,7 +42,7 @@ public class LayerValidation {
     /**
      * Asserts that the layer nIn and nOut values are set for the layer
      *
-     * @param layerType Type of layer ("DenseLayer", etc)
+     * @param layerType Type of layer ("DenseLayerConfiguration", etc)
      * @param layerName Name of the layer (may be null if not set)
      * @param layerIndex Index of the layer
      * @param nIn nIn value
@@ -60,7 +60,7 @@ public class LayerValidation {
     /**
      * Asserts that the layer nOut value is set for the layer
      *
-     * @param layerType Type of layer ("DenseLayer", etc)
+     * @param layerType Type of layer ("DenseLayerConfiguration", etc)
      * @param layerName Name of the layer (may be null if not set)
      * @param layerIndex Index of the layer
      * @param nOut nOut value
@@ -147,7 +147,7 @@ public class LocalResponseNormalization extends Layer {
 
         return new LayerMemoryReport.Builder(layerName, DenseLayer.class, inputType, inputType).standardMemory(0, 0)
                         .workingMemory(0, 2 * actElementsPerEx, 0, 3 * actElementsPerEx)
-                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer
+                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration
                         .build();
     }
 
@@ -87,7 +87,7 @@ public class PrimaryCapsules extends SameDiffLayer {
         }
 
         if(capsules < 0){
-            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \""
+            throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \""
                     + layerName + "\"):"
                     + " capsules must be >= 0 if set. Got: "
                     + capsules);
@@ -113,7 +113,7 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l
         return new LayerMemoryReport.Builder(layerName, ElementWiseMultiplicationLayer.class, inputType, outputType)
                         .standardMemory(numParams, updaterStateSize)
                         .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference
-                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer
+                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration
                         .build();
     }
 
@@ -44,7 +44,7 @@ public class TimeDistributed extends BaseWrapperLayer {
     private RNNFormat rnnDataFormat = RNNFormat.NCW;
 
     /**
-     * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer
+     * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayerConfiguration
     */
    public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
        super(underlying);
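Usage sketch for the constructor above (values and import paths are assumptions based on the stock DL4J layout): wrap a feed-forward layer so it is applied independently at every time step of an NCW-format sequence.

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.nd4j.linalg.activations.Activation;

public class TimeDistributedSketch {
    static TimeDistributed wrap() {
        return new TimeDistributed(
            new DenseLayer.Builder().nIn(64).nOut(32).activation(Activation.TANH).build(),
            RNNFormat.NCW);   // same dense weights reused at each time step
    }
}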
@@ -33,7 +33,7 @@ public abstract class SameDiffLambdaLayer extends SameDiffLayer {
     * The defineLayer method is used to define the forward pass for the layer
     *
     * @param sameDiff SameDiff instance to use to define the vertex
-    * @param layerInput Layer input variable
+    * @param layerInput ILayer input variable
     * @return The output variable (corresponding to the output activations for the layer)
     */
    public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput);
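A minimal sketch of a concrete lambda layer, assuming only what the Javadoc states; depending on the version, getOutputType(...) may also need to be overridden, and the scaling factor here is an arbitrary example:

import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

public class TimesTwoLambdaLayer extends SameDiffLambdaLayer {
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        return layerInput.mul(2.0); // forward pass only; gradients come from SameDiff autodiff
    }
}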
@@ -37,7 +37,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
     * The defineVertex method is used to define the foward pass for the vertex
     *
     * @param sameDiff SameDiff instance to use to define the vertex
-    * @param inputs Layer input variable
+    * @param inputs ILayer input variable
     * @return The output variable (orresponding to the output activations for the vertex)
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs);