Playing with some new code

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-03-23 17:39:00 +01:00
parent 4665c5a10a
commit fec570ff98
18 changed files with 2160 additions and 735 deletions

View File

@@ -21,15 +21,10 @@
package net.brutex.gan;
import java.util.List;
import java.util.Random;
import javax.ws.rs.client.ClientBuilder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.Writable;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
@@ -37,34 +32,29 @@ import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.datavec.image.transform.ScaleImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
import org.glassfish.jersey.client.JerseyClient;
import org.glassfish.jersey.client.JerseyClientBuilder;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.swing.*;
@@ -106,6 +96,8 @@ public class App {
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
.build()
};
}
/**
@@ -114,7 +106,7 @@ public class App {
* @return config
*/
private static MultiLayerConfiguration generator() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder()
.seed(42)
.updater(UPDATER)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
@@ -123,9 +115,25 @@
.activation(Activation.IDENTITY)
.list(genLayers())
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
.build();
log.debug("Generator network: \n{}", confxx.toJson());
NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build();
NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder()
.cacheMode(CacheMode.HOST)
.layer( new DenseLayer.Builder().build())
.layer( new DenseLayer.Builder().build())
.layer( BuildingBlockLayer.builder().build())
.layers( List.of(genLayers()))
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
.build();
return conf;
return confx;
}
private static Layer[] disLayers() {

View File

@@ -0,0 +1,27 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
//apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation projects.cavisDnn.cavisDnnApi
implementation projects.cavisDnn.cavisDnnNn
}

View File

@@ -0,0 +1,40 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
/**
* This is an "executable" Layer that is based on a {@link LayerConfiguration}.
*/
public interface Layer {
/**
* Get the underlying configuration for this Layer
* @return configuration
*/
LayerConfiguration getLayerConfiguration();
/**
* Set the underlying layer configuration
* @param conf The new configuration
*/
void setLayerConfiguration(LayerConfiguration conf);
}

View File

@@ -0,0 +1,42 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
public interface LayerConfiguration {
/**
* Create and return an executable {@link Layer} instance based on this configuration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
Layer instantiate(NeuralNetwork network);
/**
* Defines the valid input type for this Layer
*
* @return InputType
*/
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
}
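For illustration, here is a minimal sketch of how the Layer and LayerConfiguration interfaces above are intended to fit together: the configuration produces the executable layer, and the layer keeps a back-reference to its configuration. The class name ExampleLayerConfiguration is hypothetical and not part of this commit.

import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.inputs.InputType;

// Hypothetical example configuration, for illustration only.
public class ExampleLayerConfiguration implements LayerConfiguration {

  @Override
  public Layer instantiate(NeuralNetwork network) {
    // The configuration creates the "executable" layer; the layer remembers the configuration it came from.
    final LayerConfiguration self = this;
    return new Layer() {
      private LayerConfiguration conf = self;
      @Override public LayerConfiguration getLayerConfiguration() { return conf; }
      @Override public void setLayerConfiguration(LayerConfiguration c) { this.conf = c; }
    };
  }

  @Override
  public InputType.Type getInputType() {
    return InputType.Type.FF; // this example only accepts feed-forward input
  }
}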

View File

@@ -0,0 +1,69 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* A Neural Network is an instance of a {@link NeuralNetworkConfiguration} that can be trained,
* evaluated, saved, exported, etc. Its configuration state is defined with the
* {@link #setConfiguration(NeuralNetworkConfiguration)} and {@link #getConfiguration()} methods.
*
*/
public interface NeuralNetwork {
/**
* The configuration that defines this Neural Network
*
* @param conf the configuration to use for this network
*/
void setConfiguration(NeuralNetworkConfiguration conf);
NeuralNetworkConfiguration getConfiguration();
/**
* This method fits model with a given DataSet
*
* @param dataSet the dataset to use for training
*/
void fit(DataSet dataSet);
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet the multi dataset to use for training
*/
void fit(MultiDataSet dataSet);
/**
* The name of the Neural Network
* @return the name
*/
String getName();
/**
* Set the name for this Neural Network
* @param name the name
*/
void setName(String name);
}
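As a rough usage sketch of this interface, the intended call sequence is: set a name, attach a configuration, then fit on a DataSet or MultiDataSet. The helper class below is hypothetical and not part of the commit.

import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.api.NeuralNetworkConfiguration;
import org.nd4j.linalg.dataset.api.DataSet;

// Hypothetical helper showing the intended call sequence on the interface.
public class NeuralNetworkUsageSketch {
  public static void train(NeuralNetwork net, NeuralNetworkConfiguration conf, DataSet data) {
    net.setName("generator");    // give the network a name
    net.setConfiguration(conf);  // the configuration defines the network's state
    net.fit(data);               // train on a single DataSet
  }
}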

View File

@@ -0,0 +1,43 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
import java.util.List;
public interface NeuralNetworkConfiguration {
/**
* Provides a flat list of all embedded layer configurations. This
* can only be called after the configuration is initialized or {@link #calculateInnerLayerConfigurations()}
* has been called.
*
* @return unstacked layer configurations
*/
List<LayerConfiguration> getLayerConfigurations();
/**
* This unrolls any stacked layer configurations within building blocks like
* {@link BuildingBlockLayer}.
*/
void calculateInnerLayerConfigurations();
}

View File

@@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation projects.cavisDnn.cavisDnnNnApi
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
implementation 'org.lucee:oswego-concurrent:1.3.4'
implementation projects.cavisDnn.cavisDnnCommon

View File

@@ -0,0 +1,143 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.Singular;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.LayerConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
/**
* The NeuralNetworkConfiguration is a sequential container for the different layers in your
* network (or other NeuralNetworkConfigurations). That is, NeuralNetworkConfigurations can be
* stacked.<br/><br/>
* It then chains outputs to inputs sequentially for each NeuralNetworkConfiguration,
* finally returning the output of the "top" configuration. Any settings made are inherited and can
* be overridden on a "deeper" level. For this use case, you need to wrap the nested NeuralNetworkConfiguration
* into a BuildingBlockLayer.
*
*/
@Jacksonized
@JsonIgnoreProperties(ignoreUnknown = true)
@lombok.Builder
@Slf4j
public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable {
/**
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
* Valid values are<br/>
* CacheMode.NONE,<br/>
* CacheMode.HOST or<br/>
* CacheMode.DEVICE<br/>
*/
@NonNull
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
@Getter @Setter @NonNull
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
@Getter @Setter @NonNull
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
@Getter @Setter @NonNull
protected BackpropType backpropType = BackpropType.Standard;
@Getter
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
@Getter @Setter protected int tbpttFwdLength = 20;
@Getter @Setter protected int tbpttBackLength = 20;
/**
* The list of layer configurations in this configuration. They will be indexed automatically
* as the layers get added starting with index 0.
*/
@Singular @Getter
private List<LayerConfiguration> layerConfigurations;
/**
* The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if
* it is not specified.
*/
@lombok.Builder.Default @Getter
private String name = "Anonymous NeuralNetworkConfiguration";
/**
* The {@link InputType} of the data for this network configuration
*/
private InputType inputType;
/**
* Hidden list of layers that "flattens" all the layers of this network and applies
* inheritance.
*/
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
private final List<LayerConfiguration> innerLayerConfigurations;
@Override
public void calculateInnerLayerConfigurations() {
List<LayerConfiguration> list = new ArrayList<>();
for( LayerConfiguration layer : this.layerConfigurations) {
if(layer instanceof BuildingBlockLayer) {
BuildingBlockLayer blayer = (BuildingBlockLayer) layer;
blayer.getConf().calculateInnerLayerConfigurations();
list.addAll(blayer.getConf().getLayerConfigurations());
} else {
list.add(layer);
}
}
this.layerConfigurations = list;
}
/**
* Creates and returns a copy of this object.
*
* @return a clone of this instance.
* @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable}
* interface. Subclasses that override the {@code clone} method
* can also throw this exception to indicate that an instance
* cannot be cloned.
* @see Cloneable
*/
@Override
protected Object clone() throws CloneNotSupportedException {
return super.clone();
}
}
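A sketch of how the Lombok builder above is exercised (compare the App.java hunk at the top of this commit). The ExampleLayerConfiguration used here is the hypothetical class sketched earlier, and this assumes the builder compiles as declared; the @Singular field generates a one-at-a-time layerConfiguration(..) method.

import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.CacheMode;

public class NeuralNetworkConfigurationSketch {
  public static NeuralNetworkConfiguration build(LayerConfiguration first, LayerConfiguration second) {
    NeuralNetworkConfiguration conf = NeuralNetworkConfiguration.builder()
        .name("generator")                // overrides the "Anonymous NeuralNetworkConfiguration" default
        .cacheMode(CacheMode.HOST)
        .layerConfiguration(first)        // @Singular: add one layer configuration at a time
        .layerConfiguration(second)
        .build();
    conf.calculateInnerLayerConfigurations(); // flatten any stacked BuildingBlockLayer entries
    return conf;
  }
}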

View File

@@ -0,0 +1,35 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import net.brutex.ai.dnn.api.LayerConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
public abstract class AbstractLayerConfiguration implements LayerConfiguration {
@Getter @Setter @NonNull
private InputType.Type inputType;
}

View File

@@ -0,0 +1,52 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
@Slf4j
public class FFLayer extends AbstractLayerConfiguration {
/**
* Create and return an executable {@link Layer} instance based on this configuration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
@Override
public Layer instantiate(NeuralNetwork network) {
//Let's do some verifications first
if(getInputType() != Type.FF) {
log.error("The {} layer configuration must use an InputType of {}, but found {}",
this.getClass().getSimpleName(),
Type.FF.name(),
getInputType().name());
}
return null;
}
}

View File

@@ -0,0 +1,28 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
public abstract class LayerConfiguration {
}

View File

@@ -0,0 +1,72 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.impl.network;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.dataset.api.MultiDataSet;
public abstract class AbstractNeuralNetwork implements NeuralNetwork {
@Getter @Setter @NonNull
private String name;
@Getter @NonNull
private NeuralNetworkConfiguration configuration;
@Getter
private final Collection<TrainingListener> trainingListeners = new HashSet<>();
/**
* The neural network holds an instantiation of its configured
* layers.
* @return the actual runtime layers
*/
@Getter
private final List<Layer> runtimeLayers = new ArrayList<>();
/**
* Sets the configuration to be used. Each time a configuration is set, the runtime layers
* of this NeuralNetwork are updated from the configuration.
*
* @param conf the configuration to use for this network
*/
public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) {
List<LayerConfiguration> layers = conf.getLayerConfigurations();
for(LayerConfiguration layer : layers) {
Layer initializedLayer = layer.instantiate(this);
this.getRuntimeLayers().add(initializedLayer);
}
this.configuration = (NeuralNetworkConfiguration) conf; // keep the configuration the runtime layers were created from
}
}

View File

@@ -0,0 +1,692 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.impl.network;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex;
@Slf4j
public class NeuralNetwork extends AbstractNeuralNetwork {
//the hidden neural network layers (including output layer)
protected Layer[] layers;
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
//Current training data: input features and labels
@Getter @Setter @NonNull
protected INDArray input;
@Getter @Setter
protected INDArray labels;
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
@Getter
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
/**
* Used to call optimizers during backprop
*/
@NonNull
protected transient Solver solver = new Solver.Builder().configure(getConfiguration()).
listeners(getTrainingListeners()).model(this).build();
/**
* Create a new NeuralNetwork from the given configuration
* @param conf the configuration to build the network from
*/
public NeuralNetwork(NeuralNetworkConfiguration conf) {
if(! validateConfiguration() ) {
log.error("Configuration '{}' has failed validation.", conf.getName());
throw new RuntimeException();
}
log.info("Configuration '{}' has been validated successfully.", conf.getName());
this.conf = conf;
}
private boolean validateConfiguration() {
return true;
}
private void logNotImplemented( ) {
// getStackTrace() method return
// current method name at 0th index
String method = new Throwable()
.getStackTrace()[1]
.getMethodName();
log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName());
}
/**
* This method does initialization of model
* <p>
* PLEASE NOTE: All implementations should track own state, to avoid double spending
*/
@Override
public void init() {
logNotImplemented();
}
/**
* This method returns model parameters as single INDArray
*
* @return
*/
@Override
public INDArray params() {
logNotImplemented();
return null;
}
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null;
}
/**
* This method returns Optimizer used for training
*
* @return the optimizer
*/
@Override
public ConvexOptimizer getOptimizer() {
return solver.getOptimizer();
}
/** Get the updater for this NeuralNetwork from the Solver
* @return Updater for NeuralNetwork
*/
private Updater getUpdater(boolean initializeIfReq) {
if (solver == null && initializeIfReq) {
synchronized(this){
if(solver == null) { //May have been created while waiting for lock
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build();
solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
}
}
}
if(solver != null) {
return solver.getOptimizer().getUpdater(initializeIfReq);
}
return null;
}
/**
* Set the updater for the NeuralNetwork in the Solver
* */
public void setUpdater(@NonNull Updater updater) {
solver.getOptimizer().setUpdater(updater);
}
@Override
public void fit(MultiDataSet dataSet) {
if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
INDArray features = dataSet.getFeatures(0);
INDArray labels = dataSet.getLabels(0);
INDArray fMask = null;
INDArray lMask = null;
if (dataSet.getFeaturesMaskArrays() != null)
fMask = dataSet.getFeaturesMaskArrays()[0];
if (dataSet.getLabelsMaskArrays() != null)
lMask = dataSet.getLabelsMaskArrays()[0];
DataSet ds = new DataSet(features, labels, fMask, lMask);
fit(ds);
} else {
throw new DL4JInvalidInputException(
"MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." +
"Please consider use of ComputationGraph");
}
}
/**
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
* Equivalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
*
* @param iterator Training data (DataSetIterator). Iterator must support resetting
* @param numEpochs Number of training epochs, >= 1
*/
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs of training using an " +
"iterator that does not support resetting (iterator.resetSupported() returned false)");
for(int i = 0; i < numEpochs; i++) {
fit(iterator);
}
}
/**
* Perform minibatch training on all minibatches in the MultiDataSetIterator.<br>
* Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as
* MultiLayerNetwork only supports 1 input and 1 output)
*
* @param iterator Training data (DataSetIterator). Iterator must support resetting
*/
@Override
public void fit(MultiDataSetIterator iterator) {
fit(new MultiDataSetWrapperIterator(iterator));
}
/**
* Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.<br>
* Note that this method does not do layerwise pretraining.<br>
* For pretraining, use the pretrain(DataSetIterator) method instead.<br>
* @param iterator Training data (DataSetIterator)
*/
@Override
public void fit(DataSetIterator iterator) {
try{
fitHelper(iterator);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private synchronized void fitHelper(DataSetIterator iterator){
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
DataSetIterator iter;
boolean destructable = false;
if (iterator.asyncSupported()) {
iter = new AsyncDataSetIterator(iterator, Math.min(
Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
destructable = true;
} else {
iter = iterator;
}
for (TrainingListener tl : trainingListeners) {
tl.onEpochStart(this);
}
LayerWorkspaceMgr workspaceMgr;
if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
// as these should be closed by the time updaters are executed
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.build();
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
update(TaskUtils.buildTask(iter));
if (!iter.hasNext() && iter.resetSupported()) {
iter.reset();
}
long time1 = System.currentTimeMillis();
while (iter.hasNext()) {
DataSet next = iter.next();
long time2 = System.currentTimeMillis();
lastEtlTime.set((time2 - time1));
if (next.getFeatures() == null || next.getLabels() == null)
break;
// TODO: basically we want to wrap internals of this loop into workspace
boolean hasMaskArrays = next.hasMaskArrays();
if (conf.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(),
next.getLabelsMaskArray(), workspaceMgr);
} else {
if (hasMaskArrays)
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
setInput(next.getFeatures());
setLabels(next.getLabels());
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
.build();
}
}
//TODO CACHE
solver.optimize(workspaceMgr);
}
if (hasMaskArrays)
clearLayerMaskArrays();
time1 = System.currentTimeMillis();
synchronizeIterEpochCounts();
}
if (!trainingListeners.isEmpty()) {
for (TrainingListener tl : trainingListeners) {
tl.onEpochEnd(this);
}
}
clearLayersStates();
if (destructable)
((AsyncDataSetIterator) iter).shutdown();
incrementEpochCount();
}
/**
* Workspace for working memory for a single layer: forward pass and backward pass
* Note that this is opened/closed once per op (activate/backpropGradient call)
*/
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
/**
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
* Not used for inference
*/
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
/**
* Next 2 workspaces: used for:
* (a) Inference: holds activations for one layer only
* (b) Backprop: holds activation gradients for one layer only
* In both cases, they are opened and closed on every second layer
*/
protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
/**
* Workspace for output methods that use OutputAdapter
*/
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
/**
* Workspace for working memory in RNNs - opened and closed once per RNN time step
*/
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0)
.overallocationLimit(0.05)
.policyLearning(LearningPolicy.FIRST_LOOP)
.policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.build();
protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).build();
boolean initDone;
protected void update(Task task) {
if (!initDone) {
initDone = true;
Heartbeat heartbeat = Heartbeat.getInstance();
task = ModelSerializer.taskByModel(this);
Environment env = EnvironmentUtils.buildEnvironment();
heartbeat.reportEvent(Event.STANDALONE, env, task);
}
}
protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray,
INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
if (input.rank() != 3 || labels.rank() != 3) {
log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got "
+ Arrays.toString(input.shape()) + "\tand labels with shape "
+ Arrays.toString(labels.shape()));
return;
}
if (input.size(2) != labels.size(2)) {
log.warn("Input and label time series have different lengths: {} input length, {} label length",
input.size(2), labels.size(2));
return;
}
int fwdLen = conf.getTbpttFwdLength();
update(TaskUtils.buildTask(input, labels));
val timeSeriesLength = input.size(2);
long nSubsets = timeSeriesLength / fwdLen;
if (timeSeriesLength % fwdLen != 0)
nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
rnnClearPreviousState();
for (int i = 0; i < nSubsets; i++) {
long startTimeIdx = (long) i * fwdLen;
long endTimeIdx = startTimeIdx + fwdLen;
if (endTimeIdx > timeSeriesLength)
endTimeIdx = timeSeriesLength;
if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
featuresMaskArray, labelsMaskArray);
setInput(subsets[0]);
setLabels(subsets[1]);
setLayerMaskArrays(subsets[2], subsets[3]);
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
.build();
}
}
solver.optimize(workspaceMgr);
//Finally, update the state of the RNN layers:
updateRnnStateWithTBPTTState();
}
rnnClearPreviousState();
clearLayerMaskArrays();
}
private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels,
INDArray fMask, INDArray lMask ){
INDArray[] out = new INDArray[4];
out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
if (fMask != null) {
out[2] = fMask.get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
if (lMask != null) {
out[3] = lMask.get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
return out;
}
/**
* Intended for internal/developer use
*/
public void updateRnnStateWithTBPTTState() {
Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{});
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof RecurrentLayer) {
RecurrentLayer l = ((RecurrentLayer) layers[i]);
l.rnnSetPreviousState(l.rnnGetTBPTTState());
} else if (layers[i] instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
}
}
}
/** Clear the previous state of the RNN layers (if any).
*/
public void rnnClearPreviousState() {
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
if (layers == null)
return;
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof RecurrentLayer)
((RecurrentLayer) layers[i]).rnnClearPreviousState();
else if (layers[i] instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layers[i]).rnnClearPreviousState();
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState();
}
}
}
/** Remove the mask arrays from all layers.<br>
* See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays.
*/
public void clearLayerMaskArrays() {
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
for (Layer layer : layers) {
layer.setMaskArray(null);
}
}
/**
* Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1).
* Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
* {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc),
* the network has no way to know when one epoch ends and another starts. In such situations, this method
* can be used to increment the epoch counter.<br>
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
*
* The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()}
*/
public void incrementEpochCount(){
conf.setEpochCount(conf.getEpochCount() + 1);
synchronizeIterEpochCounts();
}
protected void synchronizeIterEpochCounts() {
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
int currIter = conf.getIterationCount();
int currEpoch = conf.getEpochCount();
log.error("Something went wrong here. Code incomplete");
/*for(Layer l : conf.getLayers()) {
l.setIterationCount(currIter);
l.setEpochCount(currEpoch);
}
*/
}
/**
* This method just makes sure there's no state preserved within layers
*/
public void clearLayersStates() {
for (Layer layer : layers) {
layer.clear();
layer.clearNoiseWeightParams();
}
}
/**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
* within the same minibatch.<br>
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
* at a given time step).<br>
* <b>NOTE</b>: This method is not usually used directly. Instead, methods such as {@link #feedForward(INDArray, INDArray, INDArray)}
* and {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally.
* @param featuresMaskArray Mask array for features (input)
* @param labelsMaskArray Mask array for labels (output)
* @see #clearLayerMaskArrays()
*/
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
if (featuresMaskArray != null) {
if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//New approach: use feedForwardMaskArray method
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
/*
//feedforward layers below a RNN layer: need the input (features) mask array
//Reason: even if the time series input is zero padded, the output from the dense layers are
// non-zero (i.e., activationFunction(0*weights + bias) != 0 in general)
//This assumes that the time series input is masked - i.e., values are 0 at the padded time steps,
// so we don't need to do anything for the recurrent layer
//Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order
// as is done for 3d -> 2d time series reshaping
INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray);
for( int i=0; i<layers.length-1; i++ ){
Type t = layers[i].type();
if( t == Type.CONVOLUTIONAL || t == Type.FEED_FORWARD ){
layers[i].setMaskArray(reshapedFeaturesMask);
} else if( t == Type.RECURRENT ) break;
}
*/
}
if (labelsMaskArray != null) {
if (!(getOutputLayer() instanceof IOutputLayer))
return;
layers[layers.length - 1].setMaskArray(labelsMaskArray);
}
}
/**
* Get the output layer - i.e., the last layer in the network
*
* @return
*/
public Layer getOutputLayer() {
Layer ret = layers[layers.length - 1];
if (ret instanceof FrozenLayerWithBackprop) {
ret = ((FrozenLayerWithBackprop) ret).getInsideLayer();
}
return ret;
}
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
int minibatchSize) {
if (maskArray == null) {
for (int i = 0; i < layers.length; i++) {
layers[i].feedForwardMaskArray(null, null, minibatchSize);
}
} else {
//Do a forward pass through each preprocessor and layer
for (int i = 0; i < layers.length; i++) {
InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i);
if (preProcessor != null) {
Pair<INDArray, MaskState> p =
preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
if (p != null) {
maskArray = p.getFirst();
currentMaskState = p.getSecond();
} else {
maskArray = null;
currentMaskState = null;
}
}
Pair<INDArray, MaskState> p =
layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
if (p != null) {
maskArray = p.getFirst();
currentMaskState = p.getSecond();
} else {
maskArray = null;
currentMaskState = null;
}
}
}
return new Pair<>(maskArray, currentMaskState);
}
}

View File

@@ -35,7 +35,7 @@ public interface NeuralNetwork {
/**
* This method does initialization of model
*
* <p>
* PLEASE NOTE: All implementations should track own state, to avoid double spending
*/
void init();
@@ -49,6 +49,7 @@ public interface NeuralNetwork {
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
INDArray updaterState();
@@ -89,14 +90,16 @@ public interface NeuralNetwork {
void fit(MultiDataSetIterator iterator);
/**
* This method executes evaluation of the model against given iterator and evaluation implementations
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
*/
<T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations);
/**
* This method executes evaluation of the model against given iterator and evaluation implementations
* This method executes evaluation of the model against given iterator and evaluation
* implementations
*
* @param iterator
*/

View File

@@ -52,6 +52,33 @@ import java.io.IOException;
import java.io.Serializable;
import java.util.*;
/**
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
* multiple layers. Everything starts with a MultiLayerConfiguration, which organizes those layers
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
* learns. They include how many times to update the weights of the model, how to initialize those
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
* and how fast the model should learn. This is what one configuration would look like:
* <br/><br/>
*
* MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()<br/>
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br/>
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br/>
* .updater(new Sgd(0.05)) //... other hyperparameters <br/>
* .list() .backprop(true)<br/>
* .build();<br/><br/>
*
* With Deeplearning4j, you add a layer
* by calling layer on the NeuralNetConfiguration.Builder(), specifying its place in the order of
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
* nIn and nOut, as well as the type: DenseLayer.<br/><br/>
*
* .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)<br/>
* .build())<br/><br/>
*
* Once you've configured your net, you train the
* model with model.fit.
*/
@Data
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
@@ -89,6 +116,252 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
//Counter for the number of epochs completed so far. Used for per-epoch schedules
protected int epochCount = 0;
/**
* Create a neural net configuration from a YAML string
*
* @param json the neural net configuration as YAML
* @return {@link MultiLayerConfiguration}
*/
public static MultiLayerConfiguration fromYaml(String json) {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
try {
return mapper.readValue(json, MultiLayerConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Create a neural net configuration from json
*
* @param json the neural net configuration from json
* @return {@link MultiLayerConfiguration}
*/
public static MultiLayerConfiguration fromJson(String json) {
MultiLayerConfiguration conf;
ObjectMapper mapper = NeuralNetConfiguration.mapper();
try {
conf = mapper.readValue(json, MultiLayerConfiguration.class);
} catch (InvalidTypeIdException e) {
if (e.getMessage().contains("@class")) {
try {
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class);
} catch (InvalidTypeIdException e2) {
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
String msg = e2.getMessage();
if (msg != null && msg.contains("Could not resolve type id")) {
throw new RuntimeException(
"Error deserializing MultiLayerConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
+
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
e);
}
throw new RuntimeException(e2);
} catch (IOException e2) {
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e) {
//Check if this exception came from legacy deserializer...
String msg = e.getMessage();
if (msg != null && msg.contains("legacy")) {
throw new RuntimeException(
"Error deserializing MultiLayerConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
+
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
e);
}
throw new RuntimeException(e);
}
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
// Previously: enumeration used for loss functions. Now: use classes
// In the past, this could only have been an OutputLayer or RnnOutputLayer using these enums
int layerCount = 0;
JsonNode confs = null;
for (NeuralNetConfiguration nnc : conf.getConfs()) {
Layer l = nnc.getLayer();
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
//if so, try walking the JSON graph to extract out the appropriate enum value
BaseOutputLayer ol = (BaseOutputLayer) l;
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null) {
return conf; //Should never happen...
}
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
JsonNode lossFunctionNode = null;
if (outputLayerNode.has("output")) {
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
} else if (outputLayerNode.has("rnnoutput")) {
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
}
if (lossFunctionNode != null) {
String lossFunctionEnumStr = lossFunctionNode.asText();
LossFunctions.LossFunction lossFunction = null;
try {
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
} catch (Exception e) {
log.warn(
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
e);
}
if (lossFunction != null) {
switch (lossFunction) {
case MSE:
ol.setLossFn(new LossMSE());
break;
case XENT:
ol.setLossFn(new LossBinaryXENT());
break;
case NEGATIVELOGLIKELIHOOD:
ol.setLossFn(new LossNegativeLogLikelihood());
break;
case MCXENT:
ol.setLossFn(new LossMCXENT());
break;
//Remaining: TODO
case SQUARED_LOSS:
case RECONSTRUCTION_CROSSENTROPY:
default:
log.warn(
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
lossFunction);
break;
}
}
}
} else {
log.warn(
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
(confs != null ? confs.getClass() : null));
}
} catch (IOException e) {
log.warn(
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
e);
break;
}
}
//Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
//Try to load the old format if necessary, and create the appropriate IActivation instance
if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null) {
return conf; //Should never happen...
}
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
continue;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode activationFunction = layerNode.get(
"activationFunction"); //Should only have 1 element: "dense", "output", etc
if (activationFunction != null) {
IActivation ia = Activation.fromString(activationFunction.asText())
.getActivationFunction();
((BaseLayer) l).setActivationFn(ia);
}
}
} catch (IOException e) {
log.warn(
"Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
e);
}
}
if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
return conf;
}
layerCount++;
}
return conf;
}
/**
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
* from handling of {@link Activation} above.
*
* @return True if all is well and layer iteration shall continue; false otherwise.
*/
private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper,
JsonNode confs, int layerCount) {
if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null) {
return false; //Should never happen...
}
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
return true;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode weightInit = layerNode.get(
"weightInit"); //Should only have 1 element: "dense", "output", etc
JsonNode distribution = layerNode.get("dist");
Distribution dist = null;
if (distribution != null) {
dist = mapper.treeToValue(distribution, Distribution.class);
}
if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
.getWeightInitFunction(dist);
((BaseLayer) l).setWeightInitFn(wi);
}
}
} catch (IOException e) {
log.warn(
"Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON",
e);
}
}
return true;
}
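A minimal sketch of the legacy JSON shape that the upgrade loop above walks, added for illustration only (it is not part of this patch): a pre-0.6.0 style configuration stores the loss function, activation and weight initialization as plain strings under the field names below, and fromJson(String) maps them onto the modern lossFn/activationFn/weightInitFn classes. A real legacy file contains many more fields than this abbreviated fragment.

// Editor's illustration - abbreviated legacy fragment; field names match the ones parsed above.
private static final String LEGACY_CONFIG_FRAGMENT =
    "{ \"confs\": [ { \"layer\": { \"output\": {"
        + " \"lossFunction\": \"MCXENT\","
        + " \"activationFunction\": \"softmax\","
        + " \"weightInit\": \"XAVIER\" } } } ] }";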
public int getEpochCount() {
return epochCount;
}

@@ -114,22 +387,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
}
/**
* Create a neural net configuration from YAML
*
* @param json the neural net configuration as YAML text
* @return {@link MultiLayerConfiguration}
*/
public static MultiLayerConfiguration fromYaml(String json) {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
try {
return mapper.readValue(json, MultiLayerConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
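As a hedged usage sketch (not part of this diff), the YAML path mirrors the JSON one: toYaml() on an existing configuration produces text that fromYaml(String) reads back.

// Editor's sketch: YAML round trip; 'source' stands for any already-built configuration.
private static MultiLayerConfiguration yamlRoundTrip(MultiLayerConfiguration source) {
String yaml = source.toYaml();
return MultiLayerConfiguration.fromYaml(yaml);
}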
/**
* @return JSON representation of NN configuration
*/

@@ -146,217 +403,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
}
/**
* Create a neural net configuration from json
*
* @param json the neural net configuration from json
* @return {@link MultiLayerConfiguration}
*/
public static MultiLayerConfiguration fromJson(String json) {
MultiLayerConfiguration conf;
ObjectMapper mapper = NeuralNetConfiguration.mapper();
try {
conf = mapper.readValue(json, MultiLayerConfiguration.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try {
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class);
} catch (InvalidTypeIdException e2){
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
String msg = e2.getMessage();
if(msg != null && msg.contains("Could not resolve type id")){
throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" +
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
}
throw new RuntimeException(e2);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e) {
//Check if this exception came from legacy deserializer...
String msg = e.getMessage();
if (msg != null && msg.contains("legacy")) {
throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " +
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e);
}
throw new RuntimeException(e);
}
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
// Previously: enumeration used for loss functions. Now: use classes
// In the past, could only have been an OutputLayer or RnnOutputLayer using these enums
int layerCount = 0;
JsonNode confs = null;
for (NeuralNetConfiguration nnc : conf.getConfs()) {
Layer l = nnc.getLayer();
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
//if so, try walking the JSON graph to extract out the appropriate enum value
BaseOutputLayer ol = (BaseOutputLayer) l;
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null)
return conf; //Should never happen...
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
JsonNode lossFunctionNode = null;
if (outputLayerNode.has("output")) {
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
} else if (outputLayerNode.has("rnnoutput")) {
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
}
if (lossFunctionNode != null) {
String lossFunctionEnumStr = lossFunctionNode.asText();
LossFunctions.LossFunction lossFunction = null;
try {
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
} catch (Exception e) {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
e);
}
if (lossFunction != null) {
switch (lossFunction) {
case MSE:
ol.setLossFn(new LossMSE());
break;
case XENT:
ol.setLossFn(new LossBinaryXENT());
break;
case NEGATIVELOGLIKELIHOOD:
ol.setLossFn(new LossNegativeLogLikelihood());
break;
case MCXENT:
ol.setLossFn(new LossMCXENT());
break;
//Remaining: TODO
case SQUARED_LOSS:
case RECONSTRUCTION_CROSSENTROPY:
default:
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
lossFunction);
break;
}
}
}
} else {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
(confs != null ? confs.getClass() : null));
}
} catch (IOException e) {
log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
e);
break;
}
}
//Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
//Try to load the old format if necessary, and create the appropriate IActivation instance
if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null)
return conf; //Should never happen...
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
continue;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode activationFunction = layerNode.get("activationFunction"); //Should only have 1 element: "dense", "output", etc
if (activationFunction != null) {
IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
((BaseLayer) l).setActivationFn(ia);
}
}
} catch (IOException e) {
log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
e);
}
}
if(!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
return conf;
}
layerCount++;
}
return conf;
}
/**
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied from handling of {@link Activation}
* above.
* @return True if all is well and layer iteration shall continue; false otherwise.
*/
private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, JsonNode confs, int layerCount) {
if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) {
try {
JsonNode jsonNode = mapper.readTree(json);
if (confs == null) {
confs = jsonNode.get("confs");
}
if (confs instanceof ArrayNode) {
ArrayNode layerConfs = (ArrayNode) confs;
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
if (outputLayerNNCNode == null)
return false; //Should never happen...
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
return true;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode weightInit = layerNode.get("weightInit"); //Should only have 1 element: "dense", "output", etc
JsonNode distribution = layerNode.get("dist");
Distribution dist = null;
if(distribution != null) {
dist = mapper.treeToValue(distribution, Distribution.class);
}
if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
((BaseLayer) l).setWeightInitFn(wi);
}
}
} catch (IOException e) {
log.warn("Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON",
e);
}
}
return true;
}
@Override
public String toString() {
return toJson();
@@ -434,12 +480,13 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
inputType = confs.get(i).getLayer().getOutputType(i, inputType);
}
return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class,
"MultiLayerNetwork", inputType);
}

/**
* For the given input shape/type for the network, return a list of activation sizes for each
* layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
*
* @param inputType Input type for the network
* @return A list of activation types for the network, indexed by layer number
@@ -482,11 +529,10 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {

/**
* Whether to override the nIn configuration forcibly upon construction. Default value is true.
*
* @param overrideNinUponBuild Whether to override the nIn configuration forcibly upon
*                             construction.
* @return builder pattern
*/
public Builder overrideNinUponBuild(boolean overrideNinUponBuild) {
@@ -495,8 +541,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}

/**
* Specify the processors. These are used at each layer for doing things like normalization and
* shaping of input.
*
* @param processor what to use to preProcess the data.
@@ -507,6 +552,23 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
return this;
}
public Builder inputPreProcessor(String layer, InputPreProcessor processor) {
int i = 0;
for (NeuralNetConfiguration conf : this.confs) {
if (conf.getLayer().getLayerName().equals(layer)) {
inputPreProcessors.put(i, processor);
log.trace("Assigned preProcessor to layer with name {} at index {}", layer, i);
break;
}
i++;
}
if (i >= this.confs.size()) {
log.warn("Could not assign preprocessor to layer with name {} as layer was not found.",
layer);
}
return this;
}
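A hedged usage sketch for the new name-based overload above (not part of the patch): the layer name "dense1", the 26x26x8 shape and the CnnToFeedForwardPreProcessor import are illustrative assumptions, and the named layer is assumed to already be present in the builder's layer list when the call is made.

// Editor's sketch: attach a preprocessor by layer name instead of by index.
private static MultiLayerConfiguration.Builder attachByName(MultiLayerConfiguration.Builder builder) {
return builder.inputPreProcessor("dense1", new CnnToFeedForwardPreProcessor(26, 26, 8));
}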
public Builder inputPreProcessors(Map<Integer, InputPreProcessor> processors) {
this.inputPreProcessors = processors;
return this;
@@ -531,10 +593,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
/**
* This method defines how/if preOutput cache is handled: NONE: cache disabled (default value)
* HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will
* be the same as for HOST)
*
* @param cacheMode
* @return
@@ -545,9 +606,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
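A one-line hedged sketch of the cache mode switch described above (not part of the patch); CacheMode.DEVICE is chosen arbitrarily for illustration.

// Editor's sketch: keep the preOutput cache in GPU memory (behaves like HOST on CPU backends).
private static MultiLayerConfiguration.Builder useDeviceCache(MultiLayerConfiguration.Builder builder) {
return builder.cacheMode(CacheMode.DEVICE);
}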
/**
* The type of backprop. Default setting is used for most networks (MLP, CNN etc), but
* optionally truncated BPTT can be used for training recurrent neural networks. If using
* TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
*/
public Builder backpropType(@NonNull BackpropType type) {
this.backpropType = type;
@@ -555,9 +616,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
/**
* When doing truncated BPTT: how many steps should we do?<br> Only applicable when doing
* backpropType(BackpropType.TruncatedBPTT)<br> See: <a
* href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
*
* @param bpttLength length > 0
*/
@@ -567,14 +628,14 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
/**
* When doing truncated BPTT: how many steps of forward pass should we do before doing
* (truncated) backprop?<br> Only applicable when doing
* backpropType(BackpropType.TruncatedBPTT)<br> Typically tBPTTForwardLength parameter is same
* as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but
* never smaller)<br> Ideally your training data time series length should be divisible by this.
* This is the k1 parameter on pg23 of
* <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
*
* @param forwardLength Forward length > 0, >= backwardLength
*/
@@ -584,10 +645,10 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
/**
* When doing truncated BPTT: how many steps of backward should we do?<br> Only applicable when
* doing backpropType(BackpropType.TruncatedBPTT)<br> This is the k2 parameter on pg23 of
* <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
*
* @param backwardLength <= forwardLength
*/
@@ -607,12 +668,12 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
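A hedged sketch tying the three TBPTT settings above together (not part of the patch); the layer sizes and the length of 20 are illustrative, and the usual LSTM/RnnOutputLayer/BackpropType/LossFunctions imports are assumed.

// Editor's sketch: truncated BPTT with matching forward and backward lengths.
private static MultiLayerConfiguration tbpttSketch() {
return new NeuralNetConfiguration.Builder()
.list()
.layer(new LSTM.Builder().nIn(10).nOut(32).build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(32).nOut(5).build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTForwardLength(20)
.tBPTTBackwardLength(20)
.build();
}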
/**
* Enabled by default. If enabled, the output layer configuration will be validated, to throw an
* exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.<br> If
* disabled (false) no output layer validation will be performed.<br> Disabling this validation
* is not recommended, as the configurations that fail validation usually will not be able to
* learn correctly. However, the option to disable this validation is provided for advanced
* users when creating non-standard architectures.
*
* @param validate If true: validate output layer configuration. False: don't validate
*/
@@ -622,10 +683,11 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
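A hedged sketch of the escape hatch described above (not part of the patch); disabling the check is only sensible for deliberately non-standard output layers.

// Editor's sketch: opt out of output layer validation for a non-standard architecture.
private static MultiLayerConfiguration.Builder allowNonStandardOutput(MultiLayerConfiguration.Builder builder) {
return builder.validateOutputLayerConfig(false);
}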
/**
* Enabled by default. If enabled, an exception will be thrown when using the (invalid)
* combination of truncated backpropagation through time (TBPTT) with either a
* GlobalPoolingLayer or LastTimeStepLayer.<br> It is possible to disable this validation to
* allow what is almost certainly an invalid configuration to be used, however this is not
* recommended.
*
* @param validate Whether TBPTT validation should be performed
*/
@@ -635,7 +697,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
/**
* Set the DataType for the network parameters and activations for all layers in the network.
* Default: Float
*
* @param dataType Datatype to use for parameters and activations
*/
public Builder dataType(@NonNull DataType dataType) {
@@ -646,9 +710,12 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
public MultiLayerConfiguration build() {
//Validate BackpropType setting
if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
&& backpropType != BackpropType.TruncatedBPTT) {
log.warn("Truncated backpropagation through time lengths have been configured with values "
+ tbpttFwdLength + " and " + tbpttBackLength + " but backprop type is set to " + backpropType
+ ". TBPTT configuration settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
}

@@ -657,15 +724,18 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
for (int i = 0; i < confs.size(); i++) {
Layer l = confs.get(i).getLayer();
if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
throw new IllegalStateException(
"Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
+ " cannot be used with layer " + i + " of type " + l.getClass().getName()
+ ": TBPTT is incompatible with this layer type (which is designed "
+ "to process entire sequences at once, and does not support the type of sequence segments that TBPTT uses).\n"
+ "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
}
}
}

if (inputType == null && inputPreProcessors.get(0) == null) {
//User hasn't set the InputType. Sometimes we can infer it...
// For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in
@@ -690,7 +760,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
}
}

//Add preprocessors and set nIns, if InputType has been set
// Builder.inputType field can be set in 1 of 4 ways:
// 1. User calls setInputType directly
@@ -723,17 +792,19 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
}
} else {
l.setNIn(currentInputType,
overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user
}
} else {
l.setNIn(currentInputType,
overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user
}
} else {
l.setNIn(currentInputType,
overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user
}

currentInputType = l.getOutputType(i, currentInputType);
}
@@ -758,7 +829,8 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
//Validate output layer configurations...
for (NeuralNetConfiguration n : conf.getConfs()) {
Layer l = n.getLayer();
OutputLayerUtil.validateOutputLayer(l.getLayerName(),
l); //No-op for non output/loss layers
}
}
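A hedged sketch of what the nIn handling in build() above enables (not part of the patch): with an InputType supplied, nIn values can be omitted on the layers and are filled in during build(), unless overrideNinUponBuild(false) was requested. The 784/64/10 sizes and the assumed DenseLayer/OutputLayer/InputType imports are illustrative.

// Editor's sketch: nIn values omitted on purpose; build() derives them from the InputType.
private static MultiLayerConfiguration inferredNInSketch() {
return new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nOut(64).activation(Activation.RELU).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(InputType.feedForward(784))
.build();
}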


@@ -0,0 +1,97 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package org.deeplearning4j.nn.conf.layers.wrapper;
import java.util.Collection;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC)
public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration {
@NonNull
@Getter
private NeuralNetworkConfiguration conf;
@Override
public Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) {
return null;
}
@Override
public ParamInitializer initializer() {
return null;
}
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
return null;
}
@Override
public void setNIn(InputType inputType, boolean override) {
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return null;
}
@Override
public boolean isPretrainParam(String paramName) {
return false;
}
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
return null;
}
/**
* Create and return an instance of a LayerConfiguration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
@Override
public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) {
return null;
}
}


@@ -101,7 +101,7 @@ import java.util.*;
@Slf4j
public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork {
//the hidden neural network layers (including output layer)
protected Layer[] layers;


@@ -100,6 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators'
include ':cavis-dnn:cavis-dnn-modelimport'
include ':cavis-dnn:cavis-dnn-nlp'
include ':cavis-dnn:cavis-dnn-nn'
include ':cavis-dnn:cavis-dnn-nn-api'
include ':cavis-dnn:cavis-dnn-nn-parent'
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server'
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client'
@@ -154,3 +155,4 @@ include ':brutex-extended-tests'
include ':cavis-full'