Playing with some new code

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-03-23 17:39:00 +01:00
parent 4665c5a10a
commit fec570ff98
18 changed files with 2160 additions and 735 deletions

View File

@@ -21,15 +21,10 @@
 package net.brutex.gan;
+import java.util.List;
 import java.util.Random;
-import javax.ws.rs.client.ClientBuilder;
 import lombok.extern.slf4j.Slf4j;
-import okhttp3.OkHttpClient;
-import okhttp3.Request;
-import okhttp3.Response;
 import org.apache.commons.lang3.ArrayUtils;
-import org.datavec.api.Writable;
-import org.datavec.api.records.reader.RecordReader;
 import org.datavec.api.split.FileSplit;
 import org.datavec.image.loader.NativeImageLoader;
 import org.datavec.image.recordreader.ImageRecordReader;
@@ -37,34 +32,29 @@ import org.datavec.image.transform.ColorConversionTransform;
 import org.datavec.image.transform.ImageTransform;
 import org.datavec.image.transform.PipelineImageTransform;
 import org.datavec.image.transform.ResizeImageTransform;
-import org.datavec.image.transform.ScaleImageTransform;
 import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
+import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.PerformanceListener;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
-import org.glassfish.jersey.client.JerseyClient;
-import org.glassfish.jersey.client.JerseyClientBuilder;
 import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationLReLU;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.IUpdater;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
 import javax.swing.*;
@@ -106,6 +96,8 @@ public class App {
         new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
             .build()
     };
 }
 /**
@@ -114,7 +106,7 @@ public class App {
  * @return config
  */
 private static MultiLayerConfiguration generator() {
-    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+    MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder()
         .seed(42)
         .updater(UPDATER)
         .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
@@ -123,9 +115,25 @@
         .activation(Activation.IDENTITY)
         .list(genLayers())
         .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
+        // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
+        .build();
+    log.debug("Generator network: \n{}", confxx.toJson());
+    NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build();
+    NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder()
+        .cacheMode(CacheMode.HOST)
+        .layer( new DenseLayer.Builder().build())
+        .layer( new DenseLayer.Builder().build())
+        .layer( BuildingBlockLayer.builder().build())
+        .layers( List.of(genLayers()))
+        .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
         .build();
-    return conf;
+    return confx;
 }
 private static Layer[] disLayers() {

View File

@@ -0,0 +1,27 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
//apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
dependencies {
implementation platform(projects.cavisCommonPlatform)
implementation projects.cavisDnn.cavisDnnApi
implementation projects.cavisDnn.cavisDnnNn
}

View File

@@ -0,0 +1,40 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
/**
* This is an "executable" Layer, that is based on a {@link LayerConfiguration}
*/
public interface Layer {
/**
* Get the underlying configuration for this Layer
* @return configuration
*/
LayerConfiguration getLayerConfiguration();
/**
* Set the underlying layer configuration
* @param conf The new configuration
*/
void setLayerConfiguration(LayerConfiguration conf);
}
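
For illustration, a minimal implementation of this interface could look like the following sketch (a hypothetical class, not part of this commit; it assumes only the interfaces defined here):

import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.LayerConfiguration;

// Hypothetical runtime layer that simply carries its configuration.
public class SimpleLayer implements Layer {
  private LayerConfiguration conf;

  public SimpleLayer(LayerConfiguration conf) {
    this.conf = conf;
  }

  @Override
  public LayerConfiguration getLayerConfiguration() {
    return conf;
  }

  @Override
  public void setLayerConfiguration(LayerConfiguration conf) {
    this.conf = conf;
  }
}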

View File

@@ -0,0 +1,42 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
public interface LayerConfiguration {
/**
 * Create and return a {@link Layer} instance based on this configuration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
Layer instantiate(NeuralNetwork network);
/**
* Defines the valid input type for this Layer
*
* @return InputType
*/
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
}
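
Taken together, a LayerConfiguration describes a layer and instantiate() produces the runtime Layer for it. A minimal sketch of that contract (a hypothetical IdentityLayerConfiguration, not part of this commit; NeuralNetwork is the interface defined in the next file):

package net.brutex.ai.dnn.api;

import org.deeplearning4j.nn.conf.inputs.InputType;

// Hypothetical configuration whose runtime layer just stores the configuration it came from.
public class IdentityLayerConfiguration implements LayerConfiguration {

  @Override
  public Layer instantiate(NeuralNetwork network) {
    return new Layer() {
      private LayerConfiguration conf = IdentityLayerConfiguration.this;

      @Override
      public LayerConfiguration getLayerConfiguration() { return conf; }

      @Override
      public void setLayerConfiguration(LayerConfiguration conf) { this.conf = conf; }
    };
  }

  @Override
  public InputType.Type getInputType() {
    return InputType.Type.FF; // declares plain feed-forward input as valid
  }
}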

View File

@@ -0,0 +1,69 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
 * A Neural Network is an instance of a {@link NeuralNetworkConfiguration} that can be trained,
 * evaluated, saved, exported, etc. Its configuration state is defined by the
 * {@link #setConfiguration(NeuralNetworkConfiguration)} and {@link #getConfiguration()} methods.
*
*/
public interface NeuralNetwork {
/**
* The configuration that defines this Neural Network
*
* @param conf the configuration to use for this network
*/
void setConfiguration(NeuralNetworkConfiguration conf);
NeuralNetworkConfiguration getConfiguration();
/**
* This method fits model with a given DataSet
*
* @param dataSet the dataset to use for training
*/
void fit(DataSet dataSet);
/**
* This method fits model with a given MultiDataSet
*
* @param dataSet the multi dataset to use for training
*/
void fit(MultiDataSet dataSet);
/**
* The name of the Neural Network
* @return the name
*/
String getName();
/**
* Set the name for this Neural Network
* @param name the name
*/
void setName(String name);
}

View File

@@ -0,0 +1,43 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
import java.util.List;
public interface NeuralNetworkConfiguration {
/**
 * Provides a flat list of all embedded layer configurations. This can only be called after
 * the network is initialized or {@link #calculateInnerLayerConfigurations()} has been called.
 *
 * @return unstacked layer configurations
 */
List<LayerConfiguration> getLayerConfigurations();
/**
 * This unrolls any stacked layer configurations within building blocks like
 * {@link BuildingBlockLayer}.
 */
void calculateInnerLayerConfigurations();
}

View File

@@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
 dependencies {
     implementation platform(projects.cavisCommonPlatform)
+    implementation projects.cavisDnn.cavisDnnNnApi
     implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
     implementation 'org.lucee:oswego-concurrent:1.3.4'
     implementation projects.cavisDnn.cavisDnnCommon

View File

@@ -0,0 +1,143 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.Singular;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.LayerConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
/**
 * The NeuralNetworkConfiguration is a sequential container for the different layers in your
 * network (or for other NeuralNetworkConfigurations); that is, NeuralNetworkConfigurations can
 * be stacked.<br/><br/>
 * It then chains outputs to inputs sequentially for each NeuralNetworkConfiguration, finally
 * returning the output of the "top" configuration. Any settings made are inherited and can be
 * overridden on a "deeper" level. For this use case, wrap the NeuralNetworkConfiguration into a
 * {@link BuildingBlockLayer}.
 */
@Jacksonized
@JsonIgnoreProperties(ignoreUnknown = true)
@lombok.Builder
@Slf4j
public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable {
/**
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
* Valid values are<br/>
* CacheMode.NONE,<br/>
* CacheMode.HOST or<br/>
* CacheMode.DEVICE<br/>
*/
@NonNull
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
@Getter @Setter @NonNull
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
@Getter @Setter @NonNull
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
@Getter @Setter @NonNull
protected BackpropType backpropType = BackpropType.Standard;
@Getter
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
@Getter @Setter protected int tbpttFwdLength = 20;
@Getter @Setter protected int tbpttBackLength = 20;
/**
* The list of layer configurations in this configuration. They will be indexed automatically
* as the layers get added starting with index 0.
*/
@Singular @Getter
private List<LayerConfiguration> layerConfigurations;
/**
* The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if
* it is not specified.
*/
@lombok.Builder.Default @Getter
private String name = "Anonymous NeuralNetworkConfiguration";
/**
* The {@link InputType} of the data for this network configuration
*/
private InputType inputType;
/**
 * Hidden list of layers that "flattens" all the layers of this network and applies
 * inheritance.
 */
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
private final List<LayerConfiguration> innerLayerConfigurations;
@Override
public void calculateInnerLayerConfigurations() {
List<LayerConfiguration> list = new ArrayList<>();
for( LayerConfiguration layer : this.layerConfigurations) {
if(layer instanceof BuildingBlockLayer) {
BuildingBlockLayer blayer = (BuildingBlockLayer) layer;
blayer.getConf().calculateInnerLayerConfigurations();
list.addAll(blayer.getConf().getLayerConfigurations());
} else {
list.add(layer);
}
}
this.layerConfigurations = list;
}
/**
* Creates and returns a copy of this object.
*
* @return a clone of this instance.
* @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable}
* interface. Subclasses that override the {@code clone} method
* can also throw this exception to indicate that an instance
* cannot be cloned.
* @see Cloneable
*/
@Override
protected Object clone() throws CloneNotSupportedException {
return super.clone();
}
}
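
To make the builder and the flattening contract concrete, here is a usage sketch. It assumes the class compiles as annotated and that Lombok's @Singular generates a layerConfiguration(...) method for the layerConfigurations list; FFLayer and BuildingBlockLayer are the classes introduced elsewhere in this commit:

import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import net.brutex.ai.dnn.conf.layer.FFLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;

public class FlattenSketch {
  public static void main(String[] args) {
    // Inner building block: two feed-forward layer configurations.
    NeuralNetworkConfiguration inner = NeuralNetworkConfiguration.builder()
        .layerConfiguration(new FFLayer())
        .layerConfiguration(new FFLayer())
        .build();

    // The outer configuration embeds the inner one through a BuildingBlockLayer.
    NeuralNetworkConfiguration outer = NeuralNetworkConfiguration.builder()
        .layerConfiguration(BuildingBlockLayer.builder().conf(inner).build())
        .build();

    // Unrolls the BuildingBlockLayer: outer then lists the two FFLayers directly.
    outer.calculateInnerLayerConfigurations();
    System.out.println(outer.getLayerConfigurations().size()); // expected: 2
  }
}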

View File

@@ -0,0 +1,35 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import net.brutex.ai.dnn.api.LayerConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
public abstract class AbstractLayerConfiguration implements LayerConfiguration {
@Getter @Setter @NonNull
private InputType.Type inputType;
}

View File

@@ -0,0 +1,52 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
@Slf4j
public class FFLayer extends AbstractLayerConfiguration {
/**
 * Create and return a {@link Layer} instance based on this configuration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
@Override
public Layer instantiate(NeuralNetwork network) {
//Let's do some verifications first
if(getInputType() != Type.FF) {
log.error("The {} layer configuration must use an InputType of {}, but found {}",
this.getClass().getSimpleName(),
Type.FF.name(),
getInputType().name());
}
return null;
}
}
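
A quick driver showing the validation behavior as committed (a hypothetical class; instantiate() currently only checks the input type and returns null):

import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.conf.layer.FFLayer;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class FFLayerSketch {
  public static void main(String[] args) {
    FFLayer ff = new FFLayer();
    ff.setInputType(InputType.Type.CNN); // deliberately wrong: instantiate() logs an error
    Layer runtime = ff.instantiate(null); // the network argument is unused so far
    System.out.println(runtime);          // null - FFLayer does not build a runtime layer yet
  }
}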

View File

@@ -0,0 +1,28 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.conf.layer;
public abstract class LayerConfiguration {
}

View File

@@ -0,0 +1,72 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.impl.network;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import net.brutex.ai.dnn.api.Layer;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.dataset.api.MultiDataSet;
public abstract class AbstractNeuralNetwork implements NeuralNetwork {
@Getter @Setter @NonNull
private String name;
@Getter @NonNull
private NeuralNetworkConfiguration configuration;
@Getter
private final Collection<TrainingListener> trainingListeners = new HashSet<>();
/**
 * The neural network holds an instantiation of its configured layers: the actual runtime
 * layers.
 */
@Getter
private final List<Layer> runtimeLayers = new ArrayList<>();
/**
* Sets the configuration to be used. Each time a configuration is set, the runtime layers
* of this NeuralNetwork are updated from the configuration.
*
* @param conf the configuration to use for this network
*/
public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) {
List<LayerConfiguration> layers = conf.getLayerConfigurations();
for(LayerConfiguration layer : layers) {
Layer initializedLayer = layer.instantiate(this);
this.getRuntimeLayers().add(initializedLayer);
}
this.configuration = (NeuralNetworkConfiguration) conf;
}
}
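
Since name, configuration handling, and listener/runtime-layer storage are provided here, a concrete subclass only has to supply the fit methods from the NeuralNetwork interface. A hypothetical stub for illustration:

import net.brutex.ai.dnn.impl.network.AbstractNeuralNetwork;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

// Hypothetical minimal subclass; training is left as a no-op.
public class TinyNetwork extends AbstractNeuralNetwork {

  @Override
  public void fit(DataSet dataSet) {
    // no-op: a real implementation would run forward/backward passes here
  }

  @Override
  public void fit(MultiDataSet dataSet) {
    // no-op
  }
}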

View File

@@ -0,0 +1,692 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.impl.network;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex;
@Slf4j
public class NeuralNetwork extends AbstractNeuralNetwork {
//the hidden neural network layers (including output layer)
protected Layer[] layers;
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
//Current training data: input features and labels
@Getter @Setter @NonNull
protected INDArray input;
@Getter @Setter
protected INDArray labels;
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
@Getter
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
/**
* Used to call optimizers during backprop
*/
@NonNull
protected transient Solver solver = new Solver.Builder().configure(getConfiguration()).
listeners(getTrainingListeners()).model(this).build();
/**
* Create a new NeuralNetwork from the given configuration
* @param conf
*/
public NeuralNetwork(NeuralNetworkConfiguration conf) {
if(! validateConfiguration() ) {
log.error("Configuration '{}' has failed validation.", conf.getName());
throw new RuntimeException();
}
log.info("Configuration '{}' has been validated successfully.", conf.getName());
this.setConfiguration(conf);
}
private boolean validateConfiguration() {
return true;
}
private void logNotImplemented( ) {
// getStackTrace() method return
// current method name at 0th index
String method = new Throwable()
.getStackTrace()[1]
.getMethodName();
log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName());
}
/**
* This method does initialization of model
* <p>
* PLEASE NOTE: All implementations should track own state, to avoid double spending
*/
@Override
public void init() {
logNotImplemented();
}
/**
* This method returns model parameters as single INDArray
*
* @return
*/
@Override
public INDArray params() {
logNotImplemented();
return null;
}
/**
* This method returns updater state (if applicable), null otherwise
*
* @return
*/
@Override
public INDArray updaterState() {
return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null;
}
/**
* This method returns Optimizer used for training
*
* @return the optimizer
*/
@Override
public ConvexOptimizer getOptimizer() {
return solver.getOptimizer();
}
/** Get the updater for this NeuralNetwork from the Solver
* @return Updater for NeuralNetwork
*/
private Updater getUpdater(boolean initializeIfReq) {
if (solver == null && initializeIfReq) {
synchronized(this){
if(solver == null) { //May have been created while waiting for lock
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build();
solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
}
}
}
if(solver != null) {
return solver.getOptimizer().getUpdater(initializeIfReq);
}
return null;
}
/**
* Set the updater for the NeuralNetwork in the Solver
* */
public void setUpdater(@NonNull Updater updater) {
solver.getOptimizer().setUpdater(updater);
}
@Override
public void fit(MultiDataSet dataSet) {
if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
INDArray features = dataSet.getFeatures(0);
INDArray labels = dataSet.getLabels(0);
INDArray fMask = null;
INDArray lMask = null;
if (dataSet.getFeaturesMaskArrays() != null)
fMask = dataSet.getFeaturesMaskArrays()[0];
if (dataSet.getLabelsMaskArrays() != null)
lMask = dataSet.getLabelsMaskArrays()[0];
DataSet ds = new DataSet(features, labels, fMask, lMask);
fit(ds);
} else {
throw new DL4JInvalidInputException(
"MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array. " +
"Please consider use of ComputationGraph instead.");
}
}
/**
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
* Equivalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
*
* @param iterator Training data (DataSetIterator). Iterator must support resetting
* @param numEpochs Number of training epochs, >= 1
*/
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs of training " +
"using an iterator that does not support resetting (iterator.resetSupported() returned false)");
for(int i = 0; i < numEpochs; i++) {
fit(iterator);
}
}
/**
* Perform minibatch training on all minibatches in the MultiDataSetIterator.<br>
* Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as
* MultiLayerNetwork only supports 1 input and 1 output)
*
* @param iterator Training data (DataSetIterator). Iterator must support resetting
*/
@Override
public void fit(MultiDataSetIterator iterator) {
fit(new MultiDataSetWrapperIterator(iterator));
}
/**
* Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.<br>
* Note that this method does not do layerwise pretraining.<br>
* For pretraining, use the pretrain(DataSetIterator) method.<br>
* @param iterator Training data (DataSetIterator)
*/
@Override
public void fit(DataSetIterator iterator) {
try{
fitHelper(iterator);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private synchronized void fitHelper(DataSetIterator iterator){
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
DataSetIterator iter;
boolean destructable = false;
if (iterator.asyncSupported()) {
iter = new AsyncDataSetIterator(iterator, Math.min(
Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
destructable = true;
} else {
iter = iterator;
}
for (TrainingListener tl : trainingListeners) {
tl.onEpochStart(this);
}
LayerWorkspaceMgr workspaceMgr;
if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
// as these should be closed by the time updaters are executed
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.build();
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
update(TaskUtils.buildTask(iter));
if (!iter.hasNext() && iter.resetSupported()) {
iter.reset();
}
long time1 = System.currentTimeMillis();
while (iter.hasNext()) {
DataSet next = iter.next();
long time2 = System.currentTimeMillis();
lastEtlTime.set((time2 - time1));
if (next.getFeatures() == null || next.getLabels() == null)
break;
// TODO: basically we want to wrap internals of this loop into workspace
boolean hasMaskArrays = next.hasMaskArrays();
if (conf.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(),
next.getLabelsMaskArray(), workspaceMgr);
} else {
if (hasMaskArrays)
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
setInput(next.getFeatures());
setLabels(next.getLabels());
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
.build();
}
}
//TODO CACHE
solver.optimize(workspaceMgr);
}
if (hasMaskArrays)
clearLayerMaskArrays();
time1 = System.currentTimeMillis();
synchronizeIterEpochCounts();
}
if (!trainingListeners.isEmpty()) {
for (TrainingListener tl : trainingListeners) {
tl.onEpochEnd(this);
}
}
clearLayersStates();
if (destructable)
((AsyncDataSetIterator) iter).shutdown();
incrementEpochCount();
}
/**
* Workspace for working memory for a single layer: forward pass and backward pass
* Note that this is opened/closed once per op (activate/backpropGradient call)
*/
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
/**
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
* Not used for inference
*/
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
/**
* Next 2 workspaces: used for:
* (a) Inference: holds activations for one layer only
* (b) Backprop: holds activation gradients for one layer only
* In both cases, they are opened and closed on every second layer
*/
protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
/**
* Workspace for output methods that use OutputAdapter
*/
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
/**
* Workspace for working memory in RNNs - opened and closed once per RNN time step
*/
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0)
.overallocationLimit(0.05)
.policyLearning(LearningPolicy.FIRST_LOOP)
.policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.build();
protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).build();
boolean initDone;
protected void update(Task task) {
if (!initDone) {
initDone = true;
Heartbeat heartbeat = Heartbeat.getInstance();
task = ModelSerializer.taskByModel(this);
Environment env = EnvironmentUtils.buildEnvironment();
heartbeat.reportEvent(Event.STANDALONE, env, task);
}
}
protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray,
INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
if (input.rank() != 3 || labels.rank() != 3) {
log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got "
+ Arrays.toString(input.shape()) + "\tand labels with shape "
+ Arrays.toString(labels.shape()));
return;
}
if (input.size(2) != labels.size(2)) {
log.warn("Input and label time series have different lengths: {} input length, {} label length",
input.size(2), labels.size(2));
return;
}
int fwdLen = conf.getTbpttFwdLength();
update(TaskUtils.buildTask(input, labels));
val timeSeriesLength = input.size(2);
long nSubsets = timeSeriesLength / fwdLen;
if (timeSeriesLength % fwdLen != 0)
nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
rnnClearPreviousState();
for (int i = 0; i < nSubsets; i++) {
long startTimeIdx = (long) i * fwdLen;
long endTimeIdx = startTimeIdx + fwdLen;
if (endTimeIdx > timeSeriesLength)
endTimeIdx = timeSeriesLength;
if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
featuresMaskArray, labelsMaskArray);
setInput(subsets[0]);
setLabels(subsets[1]);
setLayerMaskArrays(subsets[2], subsets[3]);
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
.build();
}
}
solver.optimize(workspaceMgr);
//Finally, update the state of the RNN layers:
updateRnnStateWithTBPTTState();
}
rnnClearPreviousState();
clearLayerMaskArrays();
}
private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels,
INDArray fMask, INDArray lMask ){
INDArray[] out = new INDArray[4];
out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
if (fMask != null) {
out[2] = fMask.get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
if (lMask != null) {
out[3] = lMask.get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
return out;
}
/**
* Intended for internal/developer use
*/
public void updateRnnStateWithTBPTTState() {
Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{});
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof RecurrentLayer) {
RecurrentLayer l = ((RecurrentLayer) layers[i]);
l.rnnSetPreviousState(l.rnnGetTBPTTState());
} else if (layers[i] instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
}
}
}
/** Clear the previous state of the RNN layers (if any).
*/
public void rnnClearPreviousState() {
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
if (layers == null)
return;
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof RecurrentLayer)
((RecurrentLayer) layers[i]).rnnClearPreviousState();
else if (layers[i] instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layers[i]).rnnClearPreviousState();
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState();
}
}
}
/** Remove the mask arrays from all layers.<br>
* See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays.
*/
public void clearLayerMaskArrays() {
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
for (Layer layer : layers) {
layer.setMaskArray(null);
}
}
/**
* Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1).
* Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
* {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc),
* the network has no way to know when one epoch ends and another starts. In such situations, this method
* can be used to increment the epoch counter.<br>
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
*
* The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()}
*/
public void incrementEpochCount(){
conf.setEpochCount(conf.getEpochCount() + 1);
synchronizeIterEpochCounts();
}
protected void synchronizeIterEpochCounts() {
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
int currIter = conf.getIterationCount();
int currEpoch = conf.getEpochCount();
log.error("Something went wrong here. Code incomplete");
/*for(Layer l : conf.getLayers()) {
l.setIterationCount(currIter);
l.setEpochCount(currEpoch);
}
*/
}
/**
* This method just makes sure there's no state preserved within layers
*/
public void clearLayersStates() {
for (Layer layer : layers) {
layer.clear();
layer.clearNoiseWeightParams();
}
}
/**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
* within the same minibatch.<br>
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
* at a given time step).<br>
* <b>NOTE</b>: This method is not usually used directly. Instead, methods such as {@link #feedForward(INDArray, INDArray, INDArray)}
* and {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally.
* @param featuresMaskArray Mask array for features (input)
* @param labelsMaskArray Mask array for labels (output)
* @see #clearLayerMaskArrays()
*/
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
if (featuresMaskArray != null) {
if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//New approach: use feedForwardMaskArray method
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
/*
//feedforward layers below a RNN layer: need the input (features) mask array
//Reason: even if the time series input is zero padded, the output from the dense layers are
// non-zero (i.e., activationFunction(0*weights + bias) != 0 in general)
//This assumes that the time series input is masked - i.e., values are 0 at the padded time steps,
// so we don't need to do anything for the recurrent layer
//Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order
// as is done for 3d -> 2d time series reshaping
INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray);
for( int i=0; i<layers.length-1; i++ ){
Type t = layers[i].type();
if( t == Type.CONVOLUTIONAL || t == Type.FEED_FORWARD ){
layers[i].setMaskArray(reshapedFeaturesMask);
} else if( t == Type.RECURRENT ) break;
}
*/
}
if (labelsMaskArray != null) {
if (!(getOutputLayer() instanceof IOutputLayer))
return;
layers[layers.length - 1].setMaskArray(labelsMaskArray);
}
}
/**
* Get the output layer - i.e., the last layer in the network
*
* @return
*/
public Layer getOutputLayer() {
Layer ret = layers[layers.length - 1];
if (ret instanceof FrozenLayerWithBackprop) {
ret = ((FrozenLayerWithBackprop) ret).getInsideLayer();
}
return ret;
}
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
int minibatchSize) {
if (maskArray == null) {
for (int i = 0; i < layers.length; i++) {
layers[i].feedForwardMaskArray(null, null, minibatchSize);
}
} else {
//Do a forward pass through each preprocessor and layer
for (int i = 0; i < layers.length; i++) {
InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i);
if (preProcessor != null) {
Pair<INDArray, MaskState> p =
preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
if (p != null) {
maskArray = p.getFirst();
currentMaskState = p.getSecond();
} else {
maskArray = null;
currentMaskState = null;
}
}
Pair<INDArray, MaskState> p =
layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
if (p != null) {
maskArray = p.getFirst();
currentMaskState = p.getSecond();
} else {
maskArray = null;
currentMaskState = null;
}
}
}
return new Pair<>(maskArray, currentMaskState);
}
}
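
The TBPTT subset arithmetic in doTruncatedBPTT() above is worth seeing in isolation; the following standalone sketch reproduces the same index computation (plain Java, no DL4J types):

public class TbpttSubsetsSketch {
  public static void main(String[] args) {
    long timeSeriesLength = 120;
    int fwdLen = 100;
    long nSubsets = timeSeriesLength / fwdLen;   // 1
    if (timeSeriesLength % fwdLen != 0)
      nSubsets++;                                // remainder of 20 -> 2 subsets
    for (int i = 0; i < nSubsets; i++) {
      long startTimeIdx = (long) i * fwdLen;
      long endTimeIdx = Math.min(startTimeIdx + fwdLen, timeSeriesLength);
      // prints: subset 0: [0, 100) and subset 1: [100, 120)
      System.out.println("subset " + i + ": [" + startTimeIdx + ", " + endTimeIdx + ")");
    }
  }
}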

View File

@@ -33,72 +33,75 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
  */
 public interface NeuralNetwork {
     /**
      * This method does initialization of model
-     *
+     * <p>
      * PLEASE NOTE: All implementations should track own state, to avoid double spending
      */
     void init();
     /**
      * This method returns model parameters as single INDArray
      *
      * @return
      */
     INDArray params();
     /**
      * This method returns updater state (if applicable), null otherwise
+     *
      * @return
      */
     INDArray updaterState();
     /**
      * This method returns Optimizer used for training
      *
      * @return
      */
     ConvexOptimizer getOptimizer();
     /**
      * This method fits model with a given DataSet
      *
      * @param dataSet
      */
     void fit(DataSet dataSet);
     /**
      * This method fits model with a given MultiDataSet
      *
      * @param dataSet
      */
     void fit(MultiDataSet dataSet);
     /**
      * This method fits model with a given DataSetIterator
      *
      * @param iterator
      */
     void fit(DataSetIterator iterator);
     /**
      * This method fits model with a given MultiDataSetIterator
      *
      * @param iterator
      */
     void fit(MultiDataSetIterator iterator);
     /**
-     * This method executes evaluation of the model against given iterator and evaluation implementations
+     * This method executes evaluation of the model against given iterator and evaluation
+     * implementations
      *
      * @param iterator
      */
     <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations);
     /**
-     * This method executes evaluation of the model against given iterator and evaluation implementations
+     * This method executes evaluation of the model against given iterator and evaluation
+     * implementations
      *
      * @param iterator
      */
     <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations);
 }

View File

@@ -0,0 +1,97 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package org.deeplearning4j.nn.conf.layers.wrapper;
import java.util.Collection;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC)
public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration {
@NonNull
@Getter
private NeuralNetworkConfiguration conf;
@Override
public Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) {
return null;
}
@Override
public ParamInitializer initializer() {
return null;
}
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
return null;
}
@Override
public void setNIn(InputType inputType, boolean override) {
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
return null;
}
@Override
public boolean isPretrainParam(String paramName) {
return false;
}
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
return null;
}
/**
 * Create and return a {@link net.brutex.ai.dnn.api.Layer} instance based on this configuration.
*
* @param network the "holding" network for the instance
* @return the new layer instance
*/
@Override
public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) {
return null;
}
}

View File

@@ -101,7 +101,7 @@ import java.util.*;
 @Slf4j
-public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
+public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork {
     //the hidden neural network layers (including output layer)
     protected Layer[] layers;

View File

@@ -100,6 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators'
 include ':cavis-dnn:cavis-dnn-modelimport'
 include ':cavis-dnn:cavis-dnn-nlp'
 include ':cavis-dnn:cavis-dnn-nn'
+include ':cavis-dnn:cavis-dnn-nn-api'
 include ':cavis-dnn:cavis-dnn-nn-parent'
 include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server'
 include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client'
@@ -154,3 +155,4 @@ include ':brutex-extended-tests'
 include ':cavis-full'