parent
4665c5a10a
commit
fec570ff98
|
@ -21,15 +21,10 @@
|
|||
|
||||
package net.brutex.gan;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import javax.ws.rs.client.ClientBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.datavec.api.Writable;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
|
@ -37,34 +32,29 @@ import org.datavec.image.transform.ColorConversionTransform;
|
|||
import org.datavec.image.transform.ImageTransform;
|
||||
import org.datavec.image.transform.PipelineImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
import org.datavec.image.transform.ScaleImageTransform;
|
||||
import org.datavec.image.transform.ShowImageTransform;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
||||
import org.glassfish.jersey.client.JerseyClient;
|
||||
import org.glassfish.jersey.client.JerseyClientBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
|
||||
import javax.swing.*;
|
||||
|
@ -106,6 +96,8 @@ public class App {
|
|||
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
|
||||
.build()
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -114,7 +106,7 @@ public class App {
|
|||
* @return config
|
||||
*/
|
||||
private static MultiLayerConfiguration generator() {
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
|
@ -123,9 +115,25 @@ public class App {
|
|||
.activation(Activation.IDENTITY)
|
||||
.list(genLayers())
|
||||
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
||||
.build();
|
||||
log.debug("Generator network: \n{}", confxx.toJson());
|
||||
|
||||
NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build();
|
||||
|
||||
NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder()
|
||||
.cacheMode(CacheMode.HOST)
|
||||
.layer( new DenseLayer.Builder().build())
|
||||
.layer( new DenseLayer.Builder().build())
|
||||
.layer( BuildingBlockLayer.builder().build())
|
||||
.layers( List.of(genLayers()))
|
||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
.build();
|
||||
|
||||
return conf;
|
||||
|
||||
|
||||
|
||||
return confx;
|
||||
}
|
||||
|
||||
private static Layer[] disLayers() {
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
//apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
||||
|
||||
dependencies {
|
||||
implementation platform(projects.cavisCommonPlatform)
|
||||
implementation projects.cavisDnn.cavisDnnApi
|
||||
implementation projects.cavisDnn.cavisDnnNn
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
/**
|
||||
* This is an "executable" Layer, that is based on a {@link LayerConfiguration}
|
||||
*/
|
||||
public interface Layer {
|
||||
|
||||
/**
|
||||
* Get the underlying configuration for this Layer
|
||||
* @return configuration
|
||||
*/
|
||||
LayerConfiguration getLayerConfiguration();
|
||||
|
||||
/**
|
||||
* Set the underlying layer configuration
|
||||
* @param conf The new configuration
|
||||
*/
|
||||
void setLayerConfiguration(LayerConfiguration conf);
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
public interface LayerConfiguration {
|
||||
|
||||
/**
|
||||
* Create and return an instance of a LayerConfiguration.
|
||||
*
|
||||
* @param network the "holding" network for the instance
|
||||
* @return the new layer instance
|
||||
*/
|
||||
Layer instantiate(NeuralNetwork network);
|
||||
|
||||
|
||||
/**
|
||||
* Defines the valid input type for this Layer
|
||||
*
|
||||
* @return InputType
|
||||
*/
|
||||
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
|
||||
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* A Neural Network is an instance of a {@link NeuralNetworkConfiguration}, that can be trained,
|
||||
* evaluated, saved, exported, etc. Its configuration state is defined with the
|
||||
* {@link #setConfiguration(NeuralNetworkConfiguration)} and {@link #getConfiguration()} methods.
|
||||
*
|
||||
*/
|
||||
public interface NeuralNetwork {
|
||||
|
||||
/**
|
||||
* The configuration that defines this Neural Network
|
||||
*
|
||||
* @param conf the configuration to use for this network
|
||||
*/
|
||||
void setConfiguration(NeuralNetworkConfiguration conf);
|
||||
NeuralNetworkConfiguration getConfiguration();
|
||||
|
||||
/**
|
||||
* This method fits model with a given DataSet
|
||||
*
|
||||
* @param dataSet the dataset to use for training
|
||||
*/
|
||||
void fit(DataSet dataSet);
|
||||
|
||||
/**
|
||||
* This method fits model with a given MultiDataSet
|
||||
*
|
||||
* @param dataSet the multi dataset to use for training
|
||||
*/
|
||||
void fit(MultiDataSet dataSet);
|
||||
|
||||
/**
|
||||
* The name of the Neural Network
|
||||
* @return the name
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Set the name for this Neural Network
|
||||
* @param name the name
|
||||
*/
|
||||
void setName(String name);
|
||||
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface NeuralNetworkConfiguration {
|
||||
|
||||
/**
|
||||
* Provides a flat list of all embedded layer configurations, this
|
||||
* can only be called after the layer is initialized or {@link #getLayerConfigurations()} is
|
||||
* called.
|
||||
*
|
||||
* @return unstacked layer configurations
|
||||
*/
|
||||
List<LayerConfiguration> getLayerConfigurations();
|
||||
|
||||
|
||||
/**
|
||||
* This uncollables any stacked layer configurations within building blocks like
|
||||
* @link BuildingBlockLayer}
|
||||
*/
|
||||
void calculateInnerLayerConfigurations();
|
||||
}
|
|
@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
|||
|
||||
dependencies {
|
||||
implementation platform(projects.cavisCommonPlatform)
|
||||
|
||||
implementation projects.cavisDnn.cavisDnnNnApi
|
||||
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
||||
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
||||
implementation projects.cavisDnn.cavisDnnCommon
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.Singular;
|
||||
import lombok.extern.jackson.Jacksonized;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
||||
|
||||
/**
|
||||
* The NeuralNetworkConfiguration is a sequential container for the different layers in your
|
||||
* network (or other NeuralNetworkConfigurations). That said, NeuralNetworkConfigurations can be
|
||||
* stacked.<br/><br/>
|
||||
* It then “chains” outputs to inputs sequentially for each NeuralNetworkConfiguration,
|
||||
* finally returning the output of the "top" configuration. Any settings made, are inherited and can
|
||||
* be overridden on a "deeper" level. For this use case, you need to wrap the NeuralNetworkConfiguration
|
||||
* into a BuildingBlockLayer
|
||||
*
|
||||
*/
|
||||
@Jacksonized
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
@lombok.Builder
|
||||
@Slf4j
|
||||
public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable {
|
||||
|
||||
/**
|
||||
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
|
||||
* Valid values are<br/>
|
||||
* CacheMode.NONE,<br/>
|
||||
* CacheMode.HOST or<br/>
|
||||
* CacheMode.DEVICE<br/>
|
||||
*/
|
||||
@NonNull
|
||||
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected BackpropType backpropType = BackpropType.Standard;
|
||||
|
||||
@Getter
|
||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||
|
||||
|
||||
@Getter @Setter protected int tbpttFwdLength = 20;
|
||||
@Getter @Setter protected int tbpttBackLength = 20;
|
||||
|
||||
|
||||
/**
|
||||
* The list of layer configurations in this configuration. They will be indexed automatically
|
||||
* as the layers get added starting with index 0.
|
||||
*/
|
||||
@Singular @Getter
|
||||
private List<LayerConfiguration> layerConfigurations;
|
||||
|
||||
/**
|
||||
* The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if
|
||||
* it is not specified.
|
||||
*/
|
||||
@lombok.Builder.Default @Getter
|
||||
private String name = "Anonymous NeuralNetworkConfiguration";
|
||||
|
||||
|
||||
/**
|
||||
* The {@link InputType} of the data for this network configuration
|
||||
*/
|
||||
private InputType inputType;
|
||||
|
||||
/**
|
||||
* hidden list of layers, that "flattens" all the layers of this network and applies
|
||||
* inheritance.
|
||||
*/
|
||||
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
|
||||
private final List<LayerConfiguration> innerLayerConfigurations;
|
||||
|
||||
@Override
|
||||
public void calculateInnerLayerConfigurations() {
|
||||
List<LayerConfiguration> list = new ArrayList<>();
|
||||
for( LayerConfiguration layer : this.layerConfigurations) {
|
||||
if(layer instanceof BuildingBlockLayer) {
|
||||
BuildingBlockLayer blayer = (BuildingBlockLayer) layer;
|
||||
blayer.getConf().calculateInnerLayerConfigurations();
|
||||
list.addAll(blayer.getConf().getLayerConfigurations());
|
||||
} else {
|
||||
list.add(layer);
|
||||
}
|
||||
}
|
||||
this.layerConfigurations = list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and returns a copy of this object.
|
||||
*
|
||||
* @return a clone of this instance.
|
||||
* @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable}
|
||||
* interface. Subclasses that override the {@code clone} method
|
||||
* can also throw this exception to indicate that an instance
|
||||
* cannot be cloned.
|
||||
* @see Cloneable
|
||||
*/
|
||||
@Override
|
||||
protected Object clone() throws CloneNotSupportedException {
|
||||
return super.clone();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
||||
public abstract class AbstractLayerConfiguration implements LayerConfiguration {
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
private InputType.Type inputType;
|
||||
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.brutex.ai.dnn.api.Layer;
|
||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
||||
import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
|
||||
|
||||
@Slf4j
|
||||
public class FFLayer extends AbstractLayerConfiguration {
|
||||
|
||||
|
||||
/**
|
||||
* Create and return an instance of a LayerConfiguration.
|
||||
*
|
||||
* @param network the "holding" network for the instance
|
||||
* @return the new layer instance
|
||||
*/
|
||||
@Override
|
||||
public Layer instantiate(NeuralNetwork network) {
|
||||
//Let's do some verifications first
|
||||
if(getInputType() != Type.FF) {
|
||||
log.error("The {} layer configuration must use an InputType of {}, but found {}",
|
||||
this.getClass().getSimpleName(),
|
||||
Type.FF.name(),
|
||||
getInputType().name());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
public abstract class LayerConfiguration {
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.impl.network;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import net.brutex.ai.dnn.api.Layer;
|
||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
public abstract class AbstractNeuralNetwork implements NeuralNetwork {
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
private String name;
|
||||
|
||||
@Getter @NonNull
|
||||
private NeuralNetworkConfiguration configuration;
|
||||
|
||||
@Getter
|
||||
private final Collection<TrainingListener> trainingListeners = new HashSet<>();
|
||||
|
||||
/**
|
||||
* The neural network holds an instantiation of its configured
|
||||
* layers.
|
||||
* @return the actual runtime layers
|
||||
*/
|
||||
@Getter
|
||||
private final List<Layer> runtimeLayers = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Sets the configuration to be used. Each time a configuration is set, the runtime layers
|
||||
* of this NeuralNetwork are updated from the configuration.
|
||||
*
|
||||
* @param conf the configuration to use for this network
|
||||
*/
|
||||
public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) {
|
||||
List<LayerConfiguration> layers = conf.getLayerConfigurations();
|
||||
for(LayerConfiguration layer : layers) {
|
||||
Layer initializedLayer = layer.instantiate(this);
|
||||
this.getRuntimeLayers().add(initializedLayer);
|
||||
}
|
||||
this.configuration = configuration;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,692 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.impl.network;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.Classifier;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.api.Updater;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.deeplearning4j.optimize.api.ConvexOptimizer;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.util.CrashReportingUtil;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.heartbeat.Heartbeat;
|
||||
import org.nd4j.linalg.heartbeat.reports.Environment;
|
||||
import org.nd4j.linalg.heartbeat.reports.Event;
|
||||
import org.nd4j.linalg.heartbeat.reports.Task;
|
||||
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
||||
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
||||
@Slf4j
|
||||
public class NeuralNetwork extends AbstractNeuralNetwork {
|
||||
|
||||
|
||||
//the hidden neural network layers (including output layer)
|
||||
protected Layer[] layers;
|
||||
|
||||
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
|
||||
|
||||
//Current training data: input features and labels
|
||||
@Getter @Setter @NonNull
|
||||
protected INDArray input;
|
||||
@Getter @Setter
|
||||
protected INDArray labels;
|
||||
|
||||
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
|
||||
@Getter
|
||||
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Used to call optimizers during backprop
|
||||
*/
|
||||
@NonNull
|
||||
protected transient Solver solver = new Solver.Builder().configure(getConfiguration()).
|
||||
listeners(getTrainingListeners()).model(this).build();
|
||||
|
||||
|
||||
/**
|
||||
* Create a new NeuralNetwork from the given configuration
|
||||
* @param conf
|
||||
*/
|
||||
public NeuralNetwork(NeuralNetworkConfiguration conf) {
|
||||
if(! validateConfiguration() ) {
|
||||
log.error("Configuration '{}' has failed validation.", conf.getName());
|
||||
throw new RuntimeException();
|
||||
}
|
||||
log.info("Configuration '{}' has been validated successfully.", conf.getName());
|
||||
this.conf = conf;
|
||||
}
|
||||
|
||||
private boolean validateConfiguration() {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private void logNotImplemented( ) {
|
||||
// getStackTrace() method return
|
||||
// current method name at 0th index
|
||||
String method = new Throwable()
|
||||
.getStackTrace()[1]
|
||||
.getMethodName();
|
||||
log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName());
|
||||
}
|
||||
|
||||
/**
|
||||
* This method does initialization of model
|
||||
* <p>
|
||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
||||
*/
|
||||
@Override
|
||||
public void init() {
|
||||
logNotImplemented();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns model parameters as single INDArray
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray params() {
|
||||
logNotImplemented();
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns updater state (if applicable), null otherwise
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray updaterState() {
|
||||
return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns Optimizer used for training
|
||||
*
|
||||
* @return the optimizer
|
||||
*/
|
||||
@Override
|
||||
public ConvexOptimizer getOptimizer() {
|
||||
return solver.getOptimizer();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/** Get the updater for this NeuralNetwork from the Solver
|
||||
* @return Updater for NeuralNetwork
|
||||
*/
|
||||
private Updater getUpdater(boolean initializeIfReq) {
|
||||
if (solver == null && initializeIfReq) {
|
||||
synchronized(this){
|
||||
if(solver == null) { //May have been created while waiting for lock
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build();
|
||||
solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
|
||||
}
|
||||
}
|
||||
}
|
||||
if(solver != null) {
|
||||
return solver.getOptimizer().getUpdater(initializeIfReq);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the updater for the NeuralNetwork in the Solver
|
||||
* */
|
||||
public void setUpdater(@NonNull Updater updater) {
|
||||
solver.getOptimizer().setUpdater(updater);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void fit(MultiDataSet dataSet) {
|
||||
if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
|
||||
INDArray features = dataSet.getFeatures(0);
|
||||
INDArray labels = dataSet.getLabels(0);
|
||||
INDArray fMask = null;
|
||||
INDArray lMask = null;
|
||||
|
||||
if (dataSet.getFeaturesMaskArrays() != null)
|
||||
fMask = dataSet.getFeaturesMaskArrays()[0];
|
||||
|
||||
if (dataSet.getFeaturesMaskArrays() != null)
|
||||
lMask = dataSet.getLabelsMaskArrays()[0];
|
||||
|
||||
DataSet ds = new DataSet(features, labels, fMask, lMask);
|
||||
fit(ds);
|
||||
} else {
|
||||
throw new DL4JInvalidInputException(
|
||||
"MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." +
|
||||
"Please consider use of ComputationGraph");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
|
||||
* Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
|
||||
*
|
||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
||||
* @param numEpochs Number of training epochs, >= 1
|
||||
*/
|
||||
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
|
||||
Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs);
|
||||
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" +
|
||||
"iterator has does not support resetting (iterator.resetSupported() returned false)");
|
||||
|
||||
for(int i = 0; i < numEpochs; i++) {
|
||||
fit(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator.<br>
|
||||
* Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as
|
||||
* MultiLayerNetwork only supports 1 input and 1 output)
|
||||
*
|
||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
||||
*/
|
||||
@Override
|
||||
public void fit(MultiDataSetIterator iterator) {
|
||||
fit(new MultiDataSetWrapperIterator(iterator));
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.<br>
|
||||
* Note that this method does not do layerwise pretraining.<br>
|
||||
* For pretraining use method pretrain.. #pretrain(DataSetIterator)<br>
|
||||
* @param iterator Training data (DataSetIterator)
|
||||
*/
|
||||
@Override
|
||||
public void fit(DataSetIterator iterator) {
|
||||
try{
|
||||
fitHelper(iterator);
|
||||
} catch (OutOfMemoryError e){
|
||||
CrashReportingUtil.writeMemoryCrashDump(this, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized void fitHelper(DataSetIterator iterator){
|
||||
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
|
||||
DataSetIterator iter;
|
||||
boolean destructable = false;
|
||||
if (iterator.asyncSupported()) {
|
||||
iter = new AsyncDataSetIterator(iterator, Math.min(
|
||||
Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
|
||||
destructable = true;
|
||||
} else {
|
||||
iter = iterator;
|
||||
}
|
||||
|
||||
for (TrainingListener tl : trainingListeners) {
|
||||
tl.onEpochStart(this);
|
||||
}
|
||||
|
||||
LayerWorkspaceMgr workspaceMgr;
|
||||
if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
|
||||
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
|
||||
} else {
|
||||
workspaceMgr = LayerWorkspaceMgr.builder()
|
||||
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
|
||||
// as these should be closed by the time updaters are executed
|
||||
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
|
||||
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.build();
|
||||
}
|
||||
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
|
||||
|
||||
update(TaskUtils.buildTask(iter));
|
||||
if (!iter.hasNext() && iter.resetSupported()) {
|
||||
iter.reset();
|
||||
}
|
||||
long time1 = System.currentTimeMillis();
|
||||
while (iter.hasNext()) {
|
||||
|
||||
DataSet next = iter.next();
|
||||
long time2 = System.currentTimeMillis();
|
||||
|
||||
lastEtlTime.set((time2 - time1));
|
||||
|
||||
if (next.getFeatures() == null || next.getLabels() == null)
|
||||
break;
|
||||
|
||||
// TODO: basically we want to wrap internals of this loop into workspace
|
||||
|
||||
|
||||
boolean hasMaskArrays = next.hasMaskArrays();
|
||||
|
||||
if (conf.getBackpropType() == BackpropType.TruncatedBPTT) {
|
||||
doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(),
|
||||
next.getLabelsMaskArray(), workspaceMgr);
|
||||
} else {
|
||||
if (hasMaskArrays)
|
||||
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
|
||||
|
||||
setInput(next.getFeatures());
|
||||
setLabels(next.getLabels());
|
||||
|
||||
if (solver == null) {
|
||||
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
//TODO CACHE
|
||||
solver.optimize(workspaceMgr);
|
||||
}
|
||||
|
||||
if (hasMaskArrays)
|
||||
clearLayerMaskArrays();
|
||||
|
||||
time1 = System.currentTimeMillis();
|
||||
synchronizeIterEpochCounts();
|
||||
}
|
||||
|
||||
if (!trainingListeners.isEmpty()) {
|
||||
for (TrainingListener tl : trainingListeners) {
|
||||
tl.onEpochEnd(this);
|
||||
}
|
||||
}
|
||||
|
||||
clearLayersStates();
|
||||
|
||||
if (destructable)
|
||||
((AsyncDataSetIterator) iter).shutdown();
|
||||
|
||||
incrementEpochCount();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Workspace for working memory for a single layer: forward pass and backward pass
|
||||
* Note that this is opened/closed once per op (activate/backpropGradient call)
|
||||
*/
|
||||
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
|
||||
/**
|
||||
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
|
||||
* Not used for inference
|
||||
*/
|
||||
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
|
||||
/**
|
||||
* Next 2 workspaces: used for:
|
||||
* (a) Inference: holds activations for one layer only
|
||||
* (b) Backprop: holds activation gradients for one layer only
|
||||
* In both cases, they are opened and closed on every second layer
|
||||
*/
|
||||
protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
|
||||
protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
|
||||
|
||||
/**
|
||||
* Workspace for output methods that use OutputAdapter
|
||||
*/
|
||||
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
|
||||
|
||||
/**
|
||||
* Workspace for working memory in RNNs - opened and closed once per RNN time step
|
||||
*/
|
||||
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
|
||||
|
||||
|
||||
protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
|
||||
|
||||
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
|
||||
.initialSize(0)
|
||||
.overallocationLimit(0.05)
|
||||
.policyLearning(LearningPolicy.FIRST_LOOP)
|
||||
.policyReset(ResetPolicy.BLOCK_LEFT)
|
||||
.policySpill(SpillPolicy.REALLOCATE)
|
||||
.policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||
.build();
|
||||
|
||||
protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
|
||||
|
||||
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
|
||||
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
|
||||
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
|
||||
.policyLearning(LearningPolicy.FIRST_LOOP).build();
|
||||
|
||||
|
||||
boolean initDone;
|
||||
protected void update(Task task) {
|
||||
if (!initDone) {
|
||||
initDone = true;
|
||||
Heartbeat heartbeat = Heartbeat.getInstance();
|
||||
task = ModelSerializer.taskByModel(this);
|
||||
Environment env = EnvironmentUtils.buildEnvironment();
|
||||
heartbeat.reportEvent(Event.STANDALONE, env, task);
|
||||
}
|
||||
}
|
||||
|
||||
protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray,
|
||||
INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
|
||||
if (input.rank() != 3 || labels.rank() != 3) {
|
||||
log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got "
|
||||
+ Arrays.toString(input.shape()) + "\tand labels with shape "
|
||||
+ Arrays.toString(labels.shape()));
|
||||
return;
|
||||
}
|
||||
if (input.size(2) != labels.size(2)) {
|
||||
log.warn("Input and label time series have different lengths: {} input length, {} label length",
|
||||
input.size(2), labels.size(2));
|
||||
return;
|
||||
}
|
||||
|
||||
int fwdLen = conf.getTbpttFwdLength();
|
||||
update(TaskUtils.buildTask(input, labels));
|
||||
val timeSeriesLength = input.size(2);
|
||||
long nSubsets = timeSeriesLength / fwdLen;
|
||||
if (timeSeriesLength % fwdLen != 0)
|
||||
nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
|
||||
|
||||
rnnClearPreviousState();
|
||||
|
||||
for (int i = 0; i < nSubsets; i++) {
|
||||
long startTimeIdx = (long) i * fwdLen;
|
||||
long endTimeIdx = startTimeIdx + fwdLen;
|
||||
if (endTimeIdx > timeSeriesLength)
|
||||
endTimeIdx = timeSeriesLength;
|
||||
|
||||
if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
|
||||
featuresMaskArray, labelsMaskArray);
|
||||
|
||||
setInput(subsets[0]);
|
||||
setLabels(subsets[1]);
|
||||
setLayerMaskArrays(subsets[2], subsets[3]);
|
||||
|
||||
if (solver == null) {
|
||||
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
solver.optimize(workspaceMgr);
|
||||
|
||||
//Finally, update the state of the RNN layers:
|
||||
updateRnnStateWithTBPTTState();
|
||||
}
|
||||
|
||||
rnnClearPreviousState();
|
||||
clearLayerMaskArrays();
|
||||
}
|
||||
|
||||
private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels,
|
||||
INDArray fMask, INDArray lMask ){
|
||||
INDArray[] out = new INDArray[4];
|
||||
out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
|
||||
if (fMask != null) {
|
||||
out[2] = fMask.get(NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
}
|
||||
if (lMask != null) {
|
||||
out[3] = lMask.get(NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Intended for internal/developer use
|
||||
*/
|
||||
public void updateRnnStateWithTBPTTState() {
|
||||
Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{});
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
if (layers[i] instanceof RecurrentLayer) {
|
||||
RecurrentLayer l = ((RecurrentLayer) layers[i]);
|
||||
l.rnnSetPreviousState(l.rnnGetTBPTTState());
|
||||
} else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Clear the previous state of the RNN layers (if any).
|
||||
*/
|
||||
public void rnnClearPreviousState() {
|
||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
||||
if (layers == null)
|
||||
return;
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
if (layers[i] instanceof RecurrentLayer)
|
||||
((RecurrentLayer) layers[i]).rnnClearPreviousState();
|
||||
else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
((MultiLayerNetwork) layers[i]).rnnClearPreviousState();
|
||||
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
|
||||
((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/** Remove the mask arrays from all layers.<br>
|
||||
* See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays.
|
||||
*/
|
||||
public void clearLayerMaskArrays() {
|
||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
||||
for (Layer layer : layers) {
|
||||
layer.setMaskArray(null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1).
|
||||
* Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
|
||||
* {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc),
|
||||
* the network has no way to know when one epoch ends and another starts. In such situations, this method
|
||||
* can be used to increment the epoch counter.<br>
|
||||
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
|
||||
*
|
||||
* The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()}
|
||||
*/
|
||||
public void incrementEpochCount(){
|
||||
conf.setEpochCount(conf.getEpochCount() + 1);
|
||||
synchronizeIterEpochCounts();
|
||||
}
|
||||
|
||||
protected void synchronizeIterEpochCounts() {
|
||||
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
|
||||
int currIter = conf.getIterationCount();
|
||||
int currEpoch = conf.getEpochCount();
|
||||
log.error("Something went wrong here. Code incomplete");
|
||||
/*for(Layer l : conf.getLayers()) {
|
||||
l.setIterationCount(currIter);
|
||||
l.setEpochCount(currEpoch);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
/**
|
||||
* This method just makes sure there's no state preserved within layers
|
||||
*/
|
||||
public void clearLayersStates() {
|
||||
for (Layer layer : layers) {
|
||||
layer.clear();
|
||||
layer.clearNoiseWeightParams();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
|
||||
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
|
||||
* within the same minibatch.<br>
|
||||
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
|
||||
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
|
||||
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
|
||||
* at a given time step).<br>
|
||||
* <b>NOTE</b>: This method is not usually used directly. Instead, methods such as @link #feedForward(INDArray, INDArray, INDArray)}
|
||||
* and @link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally.
|
||||
* @param featuresMaskArray Mask array for features (input)
|
||||
* @param labelsMaskArray Mask array for labels (output)
|
||||
* @see #clearLayerMaskArrays()
|
||||
*/
|
||||
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
|
||||
if (featuresMaskArray != null) {
|
||||
|
||||
if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
//New approach: use feedForwardMaskArray method
|
||||
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
|
||||
|
||||
|
||||
/*
|
||||
//feedforward layers below a RNN layer: need the input (features) mask array
|
||||
//Reason: even if the time series input is zero padded, the output from the dense layers are
|
||||
// non-zero (i.e., activationFunction(0*weights + bias) != 0 in general)
|
||||
//This assumes that the time series input is masked - i.e., values are 0 at the padded time steps,
|
||||
// so we don't need to do anything for the recurrent layer
|
||||
|
||||
//Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order
|
||||
// as is done for 3d -> 2d time series reshaping
|
||||
INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray);
|
||||
|
||||
for( int i=0; i<layers.length-1; i++ ){
|
||||
Type t = layers[i].type();
|
||||
if( t == Type.CONVOLUTIONAL || t == Type.FEED_FORWARD ){
|
||||
layers[i].setMaskArray(reshapedFeaturesMask);
|
||||
} else if( t == Type.RECURRENT ) break;
|
||||
|
||||
}
|
||||
*/
|
||||
}
|
||||
if (labelsMaskArray != null) {
|
||||
if (!(getOutputLayer() instanceof IOutputLayer))
|
||||
return;
|
||||
layers[layers.length - 1].setMaskArray(labelsMaskArray);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get the output layer - i.e., the last layer in the netwok
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Layer getOutputLayer() {
|
||||
Layer ret = layers[layers.length - 1];
|
||||
if (ret instanceof FrozenLayerWithBackprop) {
|
||||
ret = ((FrozenLayerWithBackprop) ret).getInsideLayer();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
|
||||
int minibatchSize) {
|
||||
if (maskArray == null) {
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
layers[i].feedForwardMaskArray(null, null, minibatchSize);
|
||||
}
|
||||
} else {
|
||||
//Do a forward pass through each preprocessor and layer
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i);
|
||||
|
||||
if (preProcessor != null) {
|
||||
Pair<INDArray, MaskState> p =
|
||||
preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
||||
if (p != null) {
|
||||
maskArray = p.getFirst();
|
||||
currentMaskState = p.getSecond();
|
||||
} else {
|
||||
maskArray = null;
|
||||
currentMaskState = null;
|
||||
}
|
||||
}
|
||||
|
||||
Pair<INDArray, MaskState> p =
|
||||
layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
||||
if (p != null) {
|
||||
maskArray = p.getFirst();
|
||||
currentMaskState = p.getSecond();
|
||||
} else {
|
||||
maskArray = null;
|
||||
currentMaskState = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new Pair<>(maskArray, currentMaskState);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -33,72 +33,75 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|||
*/
|
||||
public interface NeuralNetwork {
|
||||
|
||||
/**
|
||||
* This method does initialization of model
|
||||
*
|
||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
||||
*/
|
||||
void init();
|
||||
/**
|
||||
* This method does initialization of model
|
||||
* <p>
|
||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
||||
*/
|
||||
void init();
|
||||
|
||||
/**
|
||||
* This method returns model parameters as single INDArray
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
INDArray params();
|
||||
/**
|
||||
* This method returns model parameters as single INDArray
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
INDArray params();
|
||||
|
||||
/**
|
||||
* This method returns updater state (if applicable), null otherwise
|
||||
* @return
|
||||
*/
|
||||
INDArray updaterState();
|
||||
/**
|
||||
* This method returns updater state (if applicable), null otherwise
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
INDArray updaterState();
|
||||
|
||||
/**
|
||||
* This method returns Optimizer used for training
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
ConvexOptimizer getOptimizer();
|
||||
/**
|
||||
* This method returns Optimizer used for training
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
ConvexOptimizer getOptimizer();
|
||||
|
||||
/**
|
||||
* This method fits model with a given DataSet
|
||||
*
|
||||
* @param dataSet
|
||||
*/
|
||||
void fit(DataSet dataSet);
|
||||
/**
|
||||
* This method fits model with a given DataSet
|
||||
*
|
||||
* @param dataSet
|
||||
*/
|
||||
void fit(DataSet dataSet);
|
||||
|
||||
/**
|
||||
* This method fits model with a given MultiDataSet
|
||||
*
|
||||
* @param dataSet
|
||||
*/
|
||||
void fit(MultiDataSet dataSet);
|
||||
/**
|
||||
* This method fits model with a given MultiDataSet
|
||||
*
|
||||
* @param dataSet
|
||||
*/
|
||||
void fit(MultiDataSet dataSet);
|
||||
|
||||
/**
|
||||
* This method fits model with a given DataSetIterator
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
void fit(DataSetIterator iterator);
|
||||
/**
|
||||
* This method fits model with a given DataSetIterator
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
void fit(DataSetIterator iterator);
|
||||
|
||||
/**
|
||||
* This method fits model with a given MultiDataSetIterator
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
void fit(MultiDataSetIterator iterator);
|
||||
/**
|
||||
* This method fits model with a given MultiDataSetIterator
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
void fit(MultiDataSetIterator iterator);
|
||||
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
<T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations);
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation
|
||||
* implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
<T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations);
|
||||
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
<T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations);
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation
|
||||
* implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
<T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations);
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,97 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.nn.conf.layers.wrapper;
|
||||
|
||||
import java.util.Collection;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC)
|
||||
public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration {
|
||||
|
||||
@NonNull
|
||||
@Getter
|
||||
private NeuralNetworkConfiguration conf;
|
||||
|
||||
@Override
|
||||
public Layer instantiate(NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create and return an instance of a LayerConfiguration.
|
||||
*
|
||||
* @param network the "holding" network for the instance
|
||||
* @return the new layer instance
|
||||
*/
|
||||
@Override
|
||||
public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) {
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -101,7 +101,7 @@ import java.util.*;
|
|||
|
||||
|
||||
@Slf4j
|
||||
public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
|
||||
public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork {
|
||||
|
||||
//the hidden neural network layers (including output layer)
|
||||
protected Layer[] layers;
|
||||
|
|
|
@ -100,6 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators'
|
|||
include ':cavis-dnn:cavis-dnn-modelimport'
|
||||
include ':cavis-dnn:cavis-dnn-nlp'
|
||||
include ':cavis-dnn:cavis-dnn-nn'
|
||||
include ':cavis-dnn:cavis-dnn-nn-api'
|
||||
include ':cavis-dnn:cavis-dnn-nn-parent'
|
||||
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server'
|
||||
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client'
|
||||
|
@ -154,3 +155,4 @@ include ':brutex-extended-tests'
|
|||
include ':cavis-full'
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue