parent 4665c5a10a
commit fec570ff98

@@ -21,15 +21,10 @@
package net.brutex.gan;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import javax.ws.rs.client.ClientBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.datavec.api.Writable;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
|
@@ -37,34 +32,29 @@ import org.datavec.image.transform.ColorConversionTransform;
|
|||
import org.datavec.image.transform.ImageTransform;
|
||||
import org.datavec.image.transform.PipelineImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
import org.datavec.image.transform.ScaleImageTransform;
|
||||
import org.datavec.image.transform.ShowImageTransform;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.listeners.PerformanceListener;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
||||
import org.glassfish.jersey.client.JerseyClient;
|
||||
import org.glassfish.jersey.client.JerseyClientBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
|
||||
import javax.swing.*;
|
||||
|
@@ -106,6 +96,8 @@ public class App {
|
|||
new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH)
|
||||
.build()
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -114,7 +106,7 @@ public class App {
|
|||
* @return config
|
||||
*/
|
||||
private static MultiLayerConfiguration generator() {
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder()
|
||||
.seed(42)
|
||||
.updater(UPDATER)
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||||
|
@@ -123,9 +115,25 @@ public class App {
|
|||
.activation(Activation.IDENTITY)
|
||||
.list(genLayers())
|
||||
.setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
// .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS))
|
||||
.build();
|
||||
log.debug("Generator network: \n{}", confxx.toJson());
|
||||
|
||||
NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build();
|
||||
|
||||
NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder()
|
||||
.cacheMode(CacheMode.HOST)
|
||||
.layer( new DenseLayer.Builder().build())
|
||||
.layer( new DenseLayer.Builder().build())
|
||||
.layer( BuildingBlockLayer.builder().build())
|
||||
.layers( List.of(genLayers()))
|
||||
.inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS))
|
||||
.build();
|
||||
|
||||
return conf;
|
||||
|
||||
|
||||
|
||||
return confx;
|
||||
}
|
||||
|
||||
private static Layer[] disLayers() {
|
||||
|
|
|
@@ -0,0 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
//apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
||||
|
||||
dependencies {
|
||||
implementation platform(projects.cavisCommonPlatform)
|
||||
implementation projects.cavisDnn.cavisDnnApi
|
||||
implementation projects.cavisDnn.cavisDnnNn
|
||||
}
|
|
@@ -0,0 +1,40 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
/**
|
||||
* This is an "executable" Layer that is based on a {@link LayerConfiguration}.
|
||||
*/
|
||||
public interface Layer {
|
||||
|
||||
/**
|
||||
* Get the underlying configuration for this Layer
|
||||
* @return configuration
|
||||
*/
|
||||
LayerConfiguration getLayerConfiguration();
|
||||
|
||||
/**
|
||||
* Set the underlying layer configuration
|
||||
* @param conf The new configuration
|
||||
*/
|
||||
void setLayerConfiguration(LayerConfiguration conf);
|
||||
}
|
|
@@ -0,0 +1,42 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
public interface LayerConfiguration {
|
||||
|
||||
/**
|
||||
* Create and return an instance of a LayerConfiguration.
|
||||
*
|
||||
* @param network the "holding" network for the instance
|
||||
* @return the new layer instance
|
||||
*/
|
||||
Layer instantiate(NeuralNetwork network);
|
||||
|
||||
|
||||
/**
|
||||
* Defines the valid input type for this Layer
|
||||
*
|
||||
* @return InputType
|
||||
*/
|
||||
org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType();
|
||||
|
||||
}
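A minimal sketch of how a concrete configuration could satisfy this contract together with the Layer interface above; DenseLayerConf and its anonymous runtime Layer are illustrative names, not classes from this commit:

public class DenseLayerConf implements LayerConfiguration {

  @Override
  public Layer instantiate(NeuralNetwork network) {
    // Return a trivial runtime Layer that is backed by this configuration.
    return new Layer() {
      private LayerConfiguration conf = DenseLayerConf.this;
      @Override public LayerConfiguration getLayerConfiguration() { return conf; }
      @Override public void setLayerConfiguration(LayerConfiguration conf) { this.conf = conf; }
    };
  }

  @Override
  public org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType() {
    // This sketch assumes a plain feed-forward input type.
    return org.deeplearning4j.nn.conf.inputs.InputType.Type.FF;
  }
}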
|
|
@@ -0,0 +1,69 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* A Neural Network is an instance of a {@link NeuralNetworkConfiguration} that can be trained,
|
||||
* evaluated, saved, exported, etc. Its configuration state is defined with the
|
||||
* {@link #setConfiguration(NeuralNetworkConfiguration)} and {@link #getConfiguration()} methods.
|
||||
*
|
||||
*/
|
||||
public interface NeuralNetwork {
|
||||
|
||||
/**
|
||||
* The configuration that defines this Neural Network
|
||||
*
|
||||
* @param conf the configuration to use for this network
|
||||
*/
|
||||
void setConfiguration(NeuralNetworkConfiguration conf);
|
||||
NeuralNetworkConfiguration getConfiguration();
|
||||
|
||||
/**
|
||||
* This method fits model with a given DataSet
|
||||
*
|
||||
* @param dataSet the dataset to use for training
|
||||
*/
|
||||
void fit(DataSet dataSet);
|
||||
|
||||
/**
|
||||
* This method fits model with a given MultiDataSet
|
||||
*
|
||||
* @param dataSet the multi dataset to use for training
|
||||
*/
|
||||
void fit(MultiDataSet dataSet);
|
||||
|
||||
/**
|
||||
* The name of the Neural Network
|
||||
* @return the name
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Set the name for this Neural Network
|
||||
* @param name the name
|
||||
*/
|
||||
void setName(String name);
|
||||
|
||||
}
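A hedged usage sketch of the contract above; the concrete network and configuration instances are placeholders, since this commit only defines the interfaces:

void trainOnce(NeuralNetwork network, NeuralNetworkConfiguration conf,
               org.nd4j.linalg.dataset.api.DataSet data) {
  network.setConfiguration(conf);       // defines the layer stack the network will run
  network.setName("example-network");   // informational only
  network.fit(data);                    // single training call on one DataSet
}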
|
|
@@ -0,0 +1,43 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface NeuralNetworkConfiguration {
|
||||
|
||||
/**
|
||||
* Provides a flat list of all embedded layer configurations. This
* can only be called after the configuration is initialized or {@link #calculateInnerLayerConfigurations()}
* has been called.
|
||||
*
|
||||
* @return unstacked layer configurations
|
||||
*/
|
||||
List<LayerConfiguration> getLayerConfigurations();
|
||||
|
||||
|
||||
/**
|
||||
* This flattens (un-stacks) any stacked layer configurations within building blocks like
* {@link BuildingBlockLayer}.
|
||||
*/
|
||||
void calculateInnerLayerConfigurations();
|
||||
}
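Based on the Javadoc above, a consumer is expected to flatten nested blocks first and then walk the resulting list; this is an illustrative sketch, not code from this commit:

void instantiateAll(NeuralNetworkConfiguration conf, NeuralNetwork network,
                    java.util.List<Layer> target) {
  conf.calculateInnerLayerConfigurations();                  // unrolls nested building blocks
  for (LayerConfiguration lc : conf.getLayerConfigurations()) {
    target.add(lc.instantiate(network));                     // one runtime Layer per flattened configuration
  }
}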
|
|
@@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
|||
|
||||
dependencies {
|
||||
implementation platform(projects.cavisCommonPlatform)
|
||||
|
||||
implementation projects.cavisDnn.cavisDnnNnApi
|
||||
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
||||
implementation 'org.lucee:oswego-concurrent:1.3.4'
|
||||
implementation projects.cavisDnn.cavisDnnCommon
|
||||
|
|
|
@@ -0,0 +1,143 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.Singular;
|
||||
import lombok.extern.jackson.Jacksonized;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.CacheMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer;
|
||||
|
||||
/**
|
||||
* The NeuralNetworkConfiguration is a sequential container for the different layers in your
|
||||
* network (or other NeuralNetworkConfigurations). That is, NeuralNetworkConfigurations can be
|
||||
* stacked.<br/><br/>
|
||||
* It then “chains” outputs to inputs sequentially for each NeuralNetworkConfiguration,
|
||||
* finally returning the output of the "top" configuration. Any settings made are inherited and can
|
||||
* be overridden on a "deeper" level. For this use case, you need to wrap the NeuralNetworkConfiguration
|
||||
* into a BuildingBlockLayer
|
||||
*
|
||||
*/
|
||||
@Jacksonized
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
@lombok.Builder
|
||||
@Slf4j
|
||||
public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable {
|
||||
|
||||
/**
|
||||
* The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise.
|
||||
* Valid values are<br/>
|
||||
* CacheMode.NONE,<br/>
|
||||
* CacheMode.HOST or<br/>
|
||||
* CacheMode.DEVICE<br/>
|
||||
*/
|
||||
@NonNull
|
||||
@lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
protected BackpropType backpropType = BackpropType.Standard;
|
||||
|
||||
@Getter
|
||||
protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
|
||||
|
||||
|
||||
@Getter @Setter protected int tbpttFwdLength = 20;
|
||||
@Getter @Setter protected int tbpttBackLength = 20;
|
||||
|
||||
|
||||
/**
|
||||
* The list of layer configurations in this configuration. They will be indexed automatically
|
||||
* as the layers get added starting with index 0.
|
||||
*/
|
||||
@Singular @Getter
|
||||
private List<LayerConfiguration> layerConfigurations;
|
||||
|
||||
/**
|
||||
* The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if
|
||||
* it is not specified.
|
||||
*/
|
||||
@lombok.Builder.Default @Getter
|
||||
private String name = "Anonymous NeuralNetworkConfiguration";
|
||||
|
||||
|
||||
/**
|
||||
* The {@link InputType} of the data for this network configuration
|
||||
*/
|
||||
private InputType inputType;
|
||||
|
||||
/**
|
||||
* Hidden list of layers that "flattens" all the layers of this network and applies
|
||||
* inheritance.
|
||||
*/
|
||||
@lombok.Builder.ObtainVia(method = "calculateInnerLayers")
|
||||
private final List<LayerConfiguration> innerLayerConfigurations;
|
||||
|
||||
@Override
|
||||
public void calculateInnerLayerConfigurations() {
|
||||
List<LayerConfiguration> list = new ArrayList<>();
|
||||
for( LayerConfiguration layer : this.layerConfigurations) {
|
||||
if(layer instanceof BuildingBlockLayer) {
|
||||
BuildingBlockLayer blayer = (BuildingBlockLayer) layer;
|
||||
blayer.getConf().calculateInnerLayerConfigurations();
|
||||
list.addAll(blayer.getConf().getLayerConfigurations());
|
||||
} else {
|
||||
list.add(layer);
|
||||
}
|
||||
}
|
||||
this.layerConfigurations = list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and returns a copy of this object.
|
||||
*
|
||||
* @return a clone of this instance.
|
||||
* @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable}
|
||||
* interface. Subclasses that override the {@code clone} method
|
||||
* can also throw this exception to indicate that an instance
|
||||
* cannot be cloned.
|
||||
* @see Cloneable
|
||||
*/
|
||||
@Override
|
||||
protected Object clone() throws CloneNotSupportedException {
|
||||
return super.clone();
|
||||
}
|
||||
}
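A builder sketch of the nesting described in the class Javadoc. The builder method names (name, cacheMode, layerConfiguration) follow from the Lombok annotations above; the conf property on BuildingBlockLayer.builder() is an assumption based on blayer.getConf() being used in calculateInnerLayerConfigurations(), and the FFLayer entries are only stand-ins for concrete layer configurations:

NeuralNetworkConfiguration inner = NeuralNetworkConfiguration.builder()
    .name("inner-block")
    .layerConfiguration(new FFLayer())
    .layerConfiguration(new FFLayer())
    .build();

NeuralNetworkConfiguration outer = NeuralNetworkConfiguration.builder()
    .name("full-network")
    .cacheMode(CacheMode.HOST)
    .layerConfiguration(BuildingBlockLayer.builder().conf(inner).build())
    .layerConfiguration(new FFLayer())
    .build();

// Replaces the BuildingBlockLayer entry with the inner block's own layer configurations.
outer.calculateInnerLayerConfigurations();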
|
|
@@ -0,0 +1,35 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
||||
public abstract class AbstractLayerConfiguration implements LayerConfiguration {
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
private InputType.Type inputType;
|
||||
|
||||
}
|
|
@@ -0,0 +1,52 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.brutex.ai.dnn.api.Layer;
|
||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
||||
import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType.Type;
|
||||
|
||||
@Slf4j
|
||||
public class FFLayer extends AbstractLayerConfiguration {
|
||||
|
||||
|
||||
/**
|
||||
* Create and return an instance of a LayerConfiguration.
|
||||
*
|
||||
* @param network the "holding" network for the instance
|
||||
* @return the new layer instance
|
||||
*/
|
||||
@Override
|
||||
public Layer instantiate(NeuralNetwork network) {
|
||||
//Let's do some verifications first
|
||||
if(getInputType() != Type.FF) {
|
||||
log.error("The {} layer configuration must use an InputType of {}, but found {}",
|
||||
this.getClass().getSimpleName(),
|
||||
Type.FF.name(),
|
||||
getInputType().name());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,28 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.conf.layer;
|
||||
|
||||
public abstract class LayerConfiguration {
|
||||
|
||||
|
||||
|
||||
}
|
|
@@ -0,0 +1,72 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.impl.network;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import net.brutex.ai.dnn.api.Layer;
|
||||
import net.brutex.ai.dnn.api.NeuralNetwork;
|
||||
import net.brutex.ai.dnn.api.LayerConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
public abstract class AbstractNeuralNetwork implements NeuralNetwork {
|
||||
|
||||
@Getter @Setter @NonNull
|
||||
private String name;
|
||||
|
||||
@Getter @NonNull
|
||||
private NeuralNetworkConfiguration configuration;
|
||||
|
||||
@Getter
|
||||
private final Collection<TrainingListener> trainingListeners = new HashSet<>();
|
||||
|
||||
/**
|
||||
* The neural network holds an instantiation of its configured
|
||||
* layers.
|
||||
* @return the actual runtime layers
|
||||
*/
|
||||
@Getter
|
||||
private final List<Layer> runtimeLayers = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Sets the configuration to be used. Each time a configuration is set, the runtime layers
|
||||
* of this NeuralNetwork are updated from the configuration.
|
||||
*
|
||||
* @param conf the configuration to use for this network
|
||||
*/
|
||||
public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) {
|
||||
List<LayerConfiguration> layers = conf.getLayerConfigurations();
|
||||
for(LayerConfiguration layer : layers) {
|
||||
Layer initializedLayer = layer.instantiate(this);
|
||||
this.getRuntimeLayers().add(initializedLayer);
|
||||
}
|
||||
this.configuration = (NeuralNetworkConfiguration) conf;
|
||||
}
|
||||
|
||||
}
|
|
@@ -0,0 +1,692 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.impl.network;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.Classifier;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.api.Updater;
|
||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
||||
import org.deeplearning4j.nn.conf.BackpropType;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.deeplearning4j.optimize.api.ConvexOptimizer;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.util.CrashReportingUtil;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.heartbeat.Heartbeat;
|
||||
import org.nd4j.linalg.heartbeat.reports.Environment;
|
||||
import org.nd4j.linalg.heartbeat.reports.Event;
|
||||
import org.nd4j.linalg.heartbeat.reports.Task;
|
||||
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
||||
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
||||
@Slf4j
|
||||
public class NeuralNetwork extends AbstractNeuralNetwork {
|
||||
|
||||
|
||||
//the hidden neural network layers (including output layer)
|
||||
protected Layer[] layers;
|
||||
|
||||
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
|
||||
|
||||
//Current training data: input features and labels
|
||||
@Getter @Setter @NonNull
|
||||
protected INDArray input;
|
||||
@Getter @Setter
|
||||
protected INDArray labels;
|
||||
|
||||
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
|
||||
@Getter
|
||||
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Used to call optimizers during backprop
|
||||
*/
|
||||
@NonNull
|
||||
protected transient Solver solver = new Solver.Builder().configure(getConfiguration()).
|
||||
listeners(getTrainingListeners()).model(this).build();
|
||||
|
||||
|
||||
/**
|
||||
* Create a new NeuralNetwork from the given configuration
|
||||
* @param conf
|
||||
*/
|
||||
public NeuralNetwork(NeuralNetworkConfiguration conf) {
|
||||
if(! validateConfiguration() ) {
|
||||
log.error("Configuration '{}' has failed validation.", conf.getName());
|
||||
throw new RuntimeException();
|
||||
}
|
||||
log.info("Configuration '{}' has been validated successfully.", conf.getName());
|
||||
this.conf = conf;
|
||||
}
|
||||
|
||||
private boolean validateConfiguration() {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private void logNotImplemented( ) {
|
||||
// getStackTrace() method return
|
||||
// current method name at 0th index
|
||||
String method = new Throwable()
|
||||
.getStackTrace()[1]
|
||||
.getMethodName();
|
||||
log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName());
|
||||
}
|
||||
|
||||
/**
|
||||
* This method does initialization of model
|
||||
* <p>
|
||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
||||
*/
|
||||
@Override
|
||||
public void init() {
|
||||
logNotImplemented();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns model parameters as single INDArray
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray params() {
|
||||
logNotImplemented();
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns updater state (if applicable), null otherwise
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray updaterState() {
|
||||
return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns Optimizer used for training
|
||||
*
|
||||
* @return the optimizer
|
||||
*/
|
||||
@Override
|
||||
public ConvexOptimizer getOptimizer() {
|
||||
return solver.getOptimizer();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/** Get the updater for this NeuralNetwork from the Solver
|
||||
* @return Updater for NeuralNetwork
|
||||
*/
|
||||
private Updater getUpdater(boolean initializeIfReq) {
|
||||
if (solver == null && initializeIfReq) {
|
||||
synchronized(this){
|
||||
if(solver == null) { //May have been created while waiting for lock
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build();
|
||||
solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
|
||||
}
|
||||
}
|
||||
}
|
||||
if(solver != null) {
|
||||
return solver.getOptimizer().getUpdater(initializeIfReq);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the updater for the NeuralNetwork in the Solver
|
||||
* */
|
||||
public void setUpdater(@NonNull Updater updater) {
|
||||
solver.getOptimizer().setUpdater(updater);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void fit(MultiDataSet dataSet) {
|
||||
if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
|
||||
INDArray features = dataSet.getFeatures(0);
|
||||
INDArray labels = dataSet.getLabels(0);
|
||||
INDArray fMask = null;
|
||||
INDArray lMask = null;
|
||||
|
||||
if (dataSet.getFeaturesMaskArrays() != null)
|
||||
fMask = dataSet.getFeaturesMaskArrays()[0];
|
||||
|
||||
if (dataSet.getFeaturesMaskArrays() != null)
|
||||
lMask = dataSet.getLabelsMaskArrays()[0];
|
||||
|
||||
DataSet ds = new DataSet(features, labels, fMask, lMask);
|
||||
fit(ds);
|
||||
} else {
|
||||
throw new DL4JInvalidInputException(
|
||||
"MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." +
|
||||
"Please consider use of ComputationGraph");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
|
||||
* Equivalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
|
||||
*
|
||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
||||
* @param numEpochs Number of training epochs, >= 1
|
||||
*/
|
||||
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
|
||||
Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs);
|
||||
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" +
|
||||
"iterator has does not support resetting (iterator.resetSupported() returned false)");
|
||||
|
||||
for(int i = 0; i < numEpochs; i++) {
|
||||
fit(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the MultiDataSetIterator.<br>
|
||||
* Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as
|
||||
* MultiLayerNetwork only supports 1 input and 1 output)
|
||||
*
|
||||
* @param iterator Training data (DataSetIterator). Iterator must support resetting
|
||||
*/
|
||||
@Override
|
||||
public void fit(MultiDataSetIterator iterator) {
|
||||
fit(new MultiDataSetWrapperIterator(iterator));
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.<br>
|
||||
* Note that this method does not do layerwise pretraining.<br>
|
||||
* For pretraining use the method {@link #pretrain(DataSetIterator)}<br>
|
||||
* @param iterator Training data (DataSetIterator)
|
||||
*/
|
||||
@Override
|
||||
public void fit(DataSetIterator iterator) {
|
||||
try{
|
||||
fitHelper(iterator);
|
||||
} catch (OutOfMemoryError e){
|
||||
CrashReportingUtil.writeMemoryCrashDump(this, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private synchronized void fitHelper(DataSetIterator iterator){
|
||||
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
|
||||
DataSetIterator iter;
|
||||
boolean destructable = false;
|
||||
if (iterator.asyncSupported()) {
|
||||
iter = new AsyncDataSetIterator(iterator, Math.min(
|
||||
Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
|
||||
destructable = true;
|
||||
} else {
|
||||
iter = iterator;
|
||||
}
|
||||
|
||||
for (TrainingListener tl : trainingListeners) {
|
||||
tl.onEpochStart(this);
|
||||
}
|
||||
|
||||
LayerWorkspaceMgr workspaceMgr;
|
||||
if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
|
||||
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
|
||||
} else {
|
||||
workspaceMgr = LayerWorkspaceMgr.builder()
|
||||
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
|
||||
// as these should be closed by the time updaters are executed
|
||||
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
|
||||
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
|
||||
.build();
|
||||
}
|
||||
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
|
||||
|
||||
update(TaskUtils.buildTask(iter));
|
||||
if (!iter.hasNext() && iter.resetSupported()) {
|
||||
iter.reset();
|
||||
}
|
||||
long time1 = System.currentTimeMillis();
|
||||
while (iter.hasNext()) {
|
||||
|
||||
DataSet next = iter.next();
|
||||
long time2 = System.currentTimeMillis();
|
||||
|
||||
lastEtlTime.set((time2 - time1));
|
||||
|
||||
if (next.getFeatures() == null || next.getLabels() == null)
|
||||
break;
|
||||
|
||||
// TODO: basically we want to wrap internals of this loop into workspace
|
||||
|
||||
|
||||
boolean hasMaskArrays = next.hasMaskArrays();
|
||||
|
||||
if (conf.getBackpropType() == BackpropType.TruncatedBPTT) {
|
||||
doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(),
|
||||
next.getLabelsMaskArray(), workspaceMgr);
|
||||
} else {
|
||||
if (hasMaskArrays)
|
||||
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
|
||||
|
||||
setInput(next.getFeatures());
|
||||
setLabels(next.getLabels());
|
||||
|
||||
if (solver == null) {
|
||||
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
//TODO CACHE
|
||||
solver.optimize(workspaceMgr);
|
||||
}
|
||||
|
||||
if (hasMaskArrays)
|
||||
clearLayerMaskArrays();
|
||||
|
||||
time1 = System.currentTimeMillis();
|
||||
synchronizeIterEpochCounts();
|
||||
}
|
||||
|
||||
if (!trainingListeners.isEmpty()) {
|
||||
for (TrainingListener tl : trainingListeners) {
|
||||
tl.onEpochEnd(this);
|
||||
}
|
||||
}
|
||||
|
||||
clearLayersStates();
|
||||
|
||||
if (destructable)
|
||||
((AsyncDataSetIterator) iter).shutdown();
|
||||
|
||||
incrementEpochCount();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Workspace for working memory for a single layer: forward pass and backward pass
|
||||
* Note that this is opened/closed once per op (activate/backpropGradient call)
|
||||
*/
|
||||
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
|
||||
/**
|
||||
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
|
||||
* Not used for inference
|
||||
*/
|
||||
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
|
||||
/**
|
||||
* Next 2 workspaces: used for:
|
||||
* (a) Inference: holds activations for one layer only
|
||||
* (b) Backprop: holds activation gradients for one layer only
|
||||
* In both cases, they are opened and closed on every second layer
|
||||
*/
|
||||
protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1";
|
||||
protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2";
|
||||
|
||||
/**
|
||||
* Workspace for output methods that use OutputAdapter
|
||||
*/
|
||||
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
|
||||
|
||||
/**
|
||||
* Workspace for working memory in RNNs - opened and closed once per RNN time step
|
||||
*/
|
||||
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
|
||||
|
||||
|
||||
protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
|
||||
|
||||
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
|
||||
.initialSize(0)
|
||||
.overallocationLimit(0.05)
|
||||
.policyLearning(LearningPolicy.FIRST_LOOP)
|
||||
.policyReset(ResetPolicy.BLOCK_LEFT)
|
||||
.policySpill(SpillPolicy.REALLOCATE)
|
||||
.policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||
.build();
|
||||
|
||||
protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
|
||||
|
||||
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
|
||||
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
|
||||
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
|
||||
.policyLearning(LearningPolicy.FIRST_LOOP).build();
|
||||
|
||||
|
||||
boolean initDone;
|
||||
protected void update(Task task) {
|
||||
if (!initDone) {
|
||||
initDone = true;
|
||||
Heartbeat heartbeat = Heartbeat.getInstance();
|
||||
task = ModelSerializer.taskByModel(this);
|
||||
Environment env = EnvironmentUtils.buildEnvironment();
|
||||
heartbeat.reportEvent(Event.STANDALONE, env, task);
|
||||
}
|
||||
}
|
||||
|
||||
protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray,
|
||||
INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) {
|
||||
if (input.rank() != 3 || labels.rank() != 3) {
|
||||
log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got "
|
||||
+ Arrays.toString(input.shape()) + "\tand labels with shape "
|
||||
+ Arrays.toString(labels.shape()));
|
||||
return;
|
||||
}
|
||||
if (input.size(2) != labels.size(2)) {
|
||||
log.warn("Input and label time series have different lengths: {} input length, {} label length",
|
||||
input.size(2), labels.size(2));
|
||||
return;
|
||||
}
|
||||
|
||||
int fwdLen = conf.getTbpttFwdLength();
|
||||
update(TaskUtils.buildTask(input, labels));
|
||||
val timeSeriesLength = input.size(2);
|
||||
long nSubsets = timeSeriesLength / fwdLen;
|
||||
if (timeSeriesLength % fwdLen != 0)
|
||||
nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
|
||||
|
||||
rnnClearPreviousState();
|
||||
|
||||
for (int i = 0; i < nSubsets; i++) {
|
||||
long startTimeIdx = (long) i * fwdLen;
|
||||
long endTimeIdx = startTimeIdx + fwdLen;
|
||||
if (endTimeIdx > timeSeriesLength)
|
||||
endTimeIdx = timeSeriesLength;
|
||||
|
||||
if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
|
||||
featuresMaskArray, labelsMaskArray);
|
||||
|
||||
setInput(subsets[0]);
|
||||
setLabels(subsets[1]);
|
||||
setLayerMaskArrays(subsets[2], subsets[3]);
|
||||
|
||||
if (solver == null) {
|
||||
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
solver.optimize(workspaceMgr);
|
||||
|
||||
//Finally, update the state of the RNN layers:
|
||||
updateRnnStateWithTBPTTState();
|
||||
}
|
||||
|
||||
rnnClearPreviousState();
|
||||
clearLayerMaskArrays();
|
||||
}
|
||||
|
||||
private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels,
|
||||
INDArray fMask, INDArray lMask ){
|
||||
INDArray[] out = new INDArray[4];
|
||||
out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
|
||||
if (fMask != null) {
|
||||
out[2] = fMask.get(NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
}
|
||||
if (lMask != null) {
|
||||
out[3] = lMask.get(NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Intended for internal/developer use
|
||||
*/
|
||||
public void updateRnnStateWithTBPTTState() {
|
||||
Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{});
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
if (layers[i] instanceof RecurrentLayer) {
|
||||
RecurrentLayer l = ((RecurrentLayer) layers[i]);
|
||||
l.rnnSetPreviousState(l.rnnGetTBPTTState());
|
||||
} else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Clear the previous state of the RNN layers (if any).
|
||||
*/
|
||||
public void rnnClearPreviousState() {
|
||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
||||
if (layers == null)
|
||||
return;
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
if (layers[i] instanceof RecurrentLayer)
|
||||
((RecurrentLayer) layers[i]).rnnClearPreviousState();
|
||||
else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
((MultiLayerNetwork) layers[i]).rnnClearPreviousState();
|
||||
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
|
||||
((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/** Remove the mask arrays from all layers.<br>
|
||||
* See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays.
|
||||
*/
|
||||
public void clearLayerMaskArrays() {
|
||||
Layer[] layers = conf.getLayers().toArray(new Layer[]{});
|
||||
for (Layer layer : layers) {
|
||||
layer.setMaskArray(null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1).
|
||||
* Note that this is done <i>automatically</i> when using iterator-based fitting methods, such as
|
||||
* {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc),
|
||||
* the network has no way to know when one epoch ends and another starts. In such situations, this method
|
||||
* can be used to increment the epoch counter.<br>
|
||||
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
|
||||
*
|
||||
* The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()}
|
||||
*/
|
||||
public void incrementEpochCount(){
|
||||
conf.setEpochCount(conf.getEpochCount() + 1);
|
||||
synchronizeIterEpochCounts();
|
||||
}
|
||||
|
||||
protected void synchronizeIterEpochCounts() {
|
||||
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
|
||||
int currIter = conf.getIterationCount();
|
||||
int currEpoch = conf.getEpochCount();
|
||||
log.error("Something went wrong here. Code incomplete");
|
||||
/*for(Layer l : conf.getLayers()) {
|
||||
l.setIterationCount(currIter);
|
||||
l.setEpochCount(currEpoch);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
/**
|
||||
* This method just makes sure there's no state preserved within layers
|
||||
*/
|
||||
public void clearLayersStates() {
|
||||
for (Layer layer : layers) {
|
||||
layer.clear();
|
||||
layer.clearNoiseWeightParams();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
|
||||
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
|
||||
* within the same minibatch.<br>
|
||||
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
|
||||
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
|
||||
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
|
||||
* at a given time step).<br>
|
||||
* <b>NOTE</b>: This method is not usually used directly. Instead, methods such as @link #feedForward(INDArray, INDArray, INDArray)}
|
||||
* and @link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally.
|
||||
* @param featuresMaskArray Mask array for features (input)
|
||||
* @param labelsMaskArray Mask array for labels (output)
|
||||
* @see #clearLayerMaskArrays()
|
||||
*/
|
||||
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
|
||||
if (featuresMaskArray != null) {
|
||||
|
||||
if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
//New approach: use feedForwardMaskArray method
|
||||
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
|
||||
|
||||
|
||||
/*
|
||||
//feedforward layers below a RNN layer: need the input (features) mask array
|
||||
//Reason: even if the time series input is zero padded, the output from the dense layers are
|
||||
// non-zero (i.e., activationFunction(0*weights + bias) != 0 in general)
|
||||
//This assumes that the time series input is masked - i.e., values are 0 at the padded time steps,
|
||||
// so we don't need to do anything for the recurrent layer
|
||||
|
||||
//Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order
|
||||
// as is done for 3d -> 2d time series reshaping
|
||||
INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray);
|
||||
|
||||
for( int i=0; i<layers.length-1; i++ ){
|
||||
Type t = layers[i].type();
|
||||
if( t == Type.CONVOLUTIONAL || t == Type.FEED_FORWARD ){
|
||||
layers[i].setMaskArray(reshapedFeaturesMask);
|
||||
} else if( t == Type.RECURRENT ) break;
|
||||
|
||||
}
|
||||
*/
|
||||
}
|
||||
if (labelsMaskArray != null) {
|
||||
if (!(getOutputLayer() instanceof IOutputLayer))
|
||||
return;
|
||||
layers[layers.length - 1].setMaskArray(labelsMaskArray);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get the output layer - i.e., the last layer in the network
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Layer getOutputLayer() {
|
||||
Layer ret = layers[layers.length - 1];
|
||||
if (ret instanceof FrozenLayerWithBackprop) {
|
||||
ret = ((FrozenLayerWithBackprop) ret).getInsideLayer();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
|
||||
int minibatchSize) {
|
||||
if (maskArray == null) {
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
layers[i].feedForwardMaskArray(null, null, minibatchSize);
|
||||
}
|
||||
} else {
|
||||
//Do a forward pass through each preprocessor and layer
|
||||
for (int i = 0; i < layers.length; i++) {
|
||||
InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i);
|
||||
|
||||
if (preProcessor != null) {
|
||||
Pair<INDArray, MaskState> p =
|
||||
preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
||||
if (p != null) {
|
||||
maskArray = p.getFirst();
|
||||
currentMaskState = p.getSecond();
|
||||
} else {
|
||||
maskArray = null;
|
||||
currentMaskState = null;
|
||||
}
|
||||
}
|
||||
|
||||
Pair<INDArray, MaskState> p =
|
||||
layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
|
||||
if (p != null) {
|
||||
maskArray = p.getFirst();
|
||||
currentMaskState = p.getSecond();
|
||||
} else {
|
||||
maskArray = null;
|
||||
currentMaskState = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new Pair<>(maskArray, currentMaskState);
|
||||
}
|
||||
|
||||
|
||||
}
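A hedged sketch of driving this implementation with a DataSetIterator, assuming a configuration that passes validation; the epoch loop simply repeats what fit(DataSetIterator) already does per pass:

void train(NeuralNetworkConfiguration conf,
           org.nd4j.linalg.dataset.api.iterator.DataSetIterator data, int epochs) {
  NeuralNetwork net = new NeuralNetwork(conf);   // the constructor rejects configurations that fail validation
  for (int i = 0; i < epochs; i++) {
    net.fit(data);                               // one pass over the iterator; the epoch counter advances internally
  }
}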
|
|
@@ -35,7 +35,7 @@ public interface NeuralNetwork {
|
|||
|
||||
/**
|
||||
* This method does initialization of model
|
||||
*
|
||||
* <p>
|
||||
* PLEASE NOTE: All implementations should track own state, to avoid double spending
|
||||
*/
|
||||
void init();
|
||||
|
@@ -49,6 +49,7 @@ public interface NeuralNetwork {
|
|||
|
||||
/**
|
||||
* This method returns updater state (if applicable), null otherwise
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
INDArray updaterState();
|
||||
|
@@ -89,14 +90,16 @@ public interface NeuralNetwork {
|
|||
void fit(MultiDataSetIterator iterator);
|
||||
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation implementations
|
||||
* This method executes evaluation of the model against given iterator and evaluation
|
||||
* implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
<T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations);
|
||||
|
||||
/**
|
||||
* This method executes evaluation of the model against given iterator and evaluation implementations
|
||||
* This method executes evaluation of the model against given iterator and evaluation
|
||||
* implementations
|
||||
*
|
||||
* @param iterator
|
||||
*/
|
||||
|
|
|
@@ -52,6 +52,33 @@ import java.io.IOException;
|
|||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
||||
* multiple layers. Everything starts with a MultiLayerConfiguration, which organizes those layers
|
||||
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
|
||||
* learns. They include how many times to update the weights of the model, how to initialize those
|
||||
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
|
||||
* and how fast the model should learn. This is what one configuration would look like:
|
||||
* <br/><br/>
|
||||
*
|
||||
* MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()<br/>
|
||||
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br/>
|
||||
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br/>
|
||||
* .updater(new Sgd(0.05)) //... other hyperparameters <br/>
|
||||
* .list() .backprop(true)<br/>
|
||||
* .build();<br/><br/>
|
||||
*
|
||||
* With Deeplearning4j, you add a layer
|
||||
* by calling layer on the NeuralNetConfiguration.Builder(), specifying its place in the order of
|
||||
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
|
||||
* nIn and nOut, as well as the type: DenseLayer.<br/><br/>
|
||||
*
|
||||
* .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)<br/>
|
||||
* .build())<br/><br/>
|
||||
*
|
||||
* Once you've configured your net, you train the
|
||||
* model with model.fit.
|
||||
*/
|
||||
@Data
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@NoArgsConstructor
|
||||
|
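To make the Javadoc above concrete, here is a minimal configuration sketch in the same style; the layer sizes, updater and seed are illustrative assumptions, and the usual org.deeplearning4j / org.nd4j imports are assumed to be on the classpath.

// Illustrative sketch of the configuration DSL described above (values are arbitrary examples).
MultiLayerConfiguration exampleConf = new NeuralNetConfiguration.Builder()
    .seed(42)
    .updater(new Sgd(0.05))
    .weightInit(WeightInit.XAVIER)
    .activation(Activation.RELU)
    .list()
    .layer(0, new DenseLayer.Builder().nIn(784).nOut(250).build())
    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX).nIn(250).nOut(10).build())
    .build();

MultiLayerNetwork exampleNet = new MultiLayerNetwork(exampleConf);
exampleNet.init();
// exampleNet.fit(trainIter); // 'trainIter' would be a DataSetIterator (hypothetical)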
@ -89,6 +116,252 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  //Counter for the number of epochs completed so far. Used for per-epoch schedules
  protected int epochCount = 0;

  /**
   * Create a neural net configuration from YAML
   *
   * @param json the neural net configuration as a YAML string
   * @return {@link MultiLayerConfiguration}
   */
  public static MultiLayerConfiguration fromYaml(String json) {
    ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
    try {
      return mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Create a neural net configuration from json
   *
   * @param json the neural net configuration from json
   * @return {@link MultiLayerConfiguration}
   */
  public static MultiLayerConfiguration fromJson(String json) {
    MultiLayerConfiguration conf;
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    try {
      conf = mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (InvalidTypeIdException e) {
      if (e.getMessage().contains("@class")) {
        try {
          //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
          return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class);
        } catch (InvalidTypeIdException e2) {
          //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
          //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
          String msg = e2.getMessage();
          if (msg != null && msg.contains("Could not resolve type id")) {
            throw new RuntimeException(
                "Error deserializing MultiLayerConfiguration - configuration may have a custom "
                    + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
                    + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
                e);
          }
          throw new RuntimeException(e2);
        } catch (IOException e2) {
          throw new RuntimeException(e2);
        }
      }
      throw new RuntimeException(e);
    } catch (IOException e) {
      //Check if this exception came from legacy deserializer...
      String msg = e.getMessage();
      if (msg != null && msg.contains("legacy")) {
        throw new RuntimeException(
            "Error deserializing MultiLayerConfiguration - configuration may have a custom "
                + "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
                + "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
            e);
      }
      throw new RuntimeException(e);
    }

    //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
    // Previously: enumeration used for loss functions. Now: use classes
    // In the past, could have only been an OutputLayer or RnnOutputLayer using these enums
    int layerCount = 0;
    JsonNode confs = null;
    for (NeuralNetConfiguration nnc : conf.getConfs()) {
      Layer l = nnc.getLayer();
      if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
        //lossFn field null -> may be an old config format, with lossFunction field being for the enum
        //if so, try walking the JSON graph to extract out the appropriate enum value

        BaseOutputLayer ol = (BaseOutputLayer) l;
        try {
          JsonNode jsonNode = mapper.readTree(json);
          if (confs == null) {
            confs = jsonNode.get("confs");
          }
          if (confs instanceof ArrayNode) {
            ArrayNode layerConfs = (ArrayNode) confs;
            JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
            if (outputLayerNNCNode == null) {
              return conf; //Should never happen...
            }
            JsonNode outputLayerNode = outputLayerNNCNode.get("layer");

            JsonNode lossFunctionNode = null;
            if (outputLayerNode.has("output")) {
              lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
            } else if (outputLayerNode.has("rnnoutput")) {
              lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
            }

            if (lossFunctionNode != null) {
              String lossFunctionEnumStr = lossFunctionNode.asText();
              LossFunctions.LossFunction lossFunction = null;
              try {
                lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
              } catch (Exception e) {
                log.warn(
                    "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
                    e);
              }

              if (lossFunction != null) {
                switch (lossFunction) {
                  case MSE:
                    ol.setLossFn(new LossMSE());
                    break;
                  case XENT:
                    ol.setLossFn(new LossBinaryXENT());
                    break;
                  case NEGATIVELOGLIKELIHOOD:
                    ol.setLossFn(new LossNegativeLogLikelihood());
                    break;
                  case MCXENT:
                    ol.setLossFn(new LossMCXENT());
                    break;

                  //Remaining: TODO
                  case SQUARED_LOSS:
                  case RECONSTRUCTION_CROSSENTROPY:
                  default:
                    log.warn(
                        "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
                        lossFunction);
                    break;
                }
              }
            }

          } else {
            log.warn(
                "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
                (confs != null ? confs.getClass() : null));
          }
        } catch (IOException e) {
          log.warn(
              "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
              e);
          break;
        }
      }

      //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
      //Try to load the old format if necessary, and create the appropriate IActivation instance
      if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) {
        try {
          JsonNode jsonNode = mapper.readTree(json);
          if (confs == null) {
            confs = jsonNode.get("confs");
          }
          if (confs instanceof ArrayNode) {
            ArrayNode layerConfs = (ArrayNode) confs;
            JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
            if (outputLayerNNCNode == null) {
              return conf; //Should never happen...
            }
            JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");

            if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
              continue;
            }

            JsonNode layerNode = layerWrapperNode.elements().next();
            JsonNode activationFunction = layerNode.get(
                "activationFunction"); //Should only have 1 element: "dense", "output", etc

            if (activationFunction != null) {
              IActivation ia = Activation.fromString(activationFunction.asText())
                  .getActivationFunction();
              ((BaseLayer) l).setActivationFn(ia);
            }
          }

        } catch (IOException e) {
          log.warn(
              "Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
              e);
        }
      }

      if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
        return conf;
      }

      layerCount++;
    }
    return conf;
  }

  /**
   * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
   * from handling of {@link Activation} above.
   *
   * @return True if all is well and layer iteration shall continue. False else-wise.
   */
  private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper,
      JsonNode confs, int layerCount) {
    if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) {
      try {
        JsonNode jsonNode = mapper.readTree(json);
        if (confs == null) {
          confs = jsonNode.get("confs");
        }
        if (confs instanceof ArrayNode) {
          ArrayNode layerConfs = (ArrayNode) confs;
          JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
          if (outputLayerNNCNode == null) {
            return false; //Should never happen...
          }
          JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");

          if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
            return true;
          }

          JsonNode layerNode = layerWrapperNode.elements().next();
          JsonNode weightInit = layerNode.get(
              "weightInit"); //Should only have 1 element: "dense", "output", etc
          JsonNode distribution = layerNode.get("dist");

          Distribution dist = null;
          if (distribution != null) {
            dist = mapper.treeToValue(distribution, Distribution.class);
          }

          if (weightInit != null) {
            final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
                .getWeightInitFunction(dist);
            ((BaseLayer) l).setWeightInitFn(wi);
          }
        }

      } catch (IOException e) {
        log.warn(
            "Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON",
            e);
      }
    }
    return true;
  }

  public int getEpochCount() {
    return epochCount;
  }
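A short, hedged round-trip sketch of the serialization path fromJson() is meant to handle; 'conf' stands for any existing MultiLayerConfiguration and is an assumption of the example.

// Hypothetical round trip: serialize an existing configuration and read it back.
String json = conf.toJson();
MultiLayerConfiguration restored = MultiLayerConfiguration.fromJson(json);
// For configs written by much older versions, fromJson() above additionally patches
// legacy fields (lossFunction enums, string "activationFunction", "weightInit"/"dist").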
@ -114,22 +387,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
    }
  }

  /**
   * Create a neural net configuration from json
   *
   * @param json the neural net configuration from json
   * @return {@link MultiLayerConfiguration}
   */
  public static MultiLayerConfiguration fromYaml(String json) {
    ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
    try {
      return mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }


  /**
   * @return JSON representation of NN configuration
   */
@ -146,217 +403,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
    }
  }

  @Override
  public String toString() {
    return toJson();
@ -434,12 +480,13 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
      inputType = confs.get(i).getLayer().getOutputType(i, inputType);
    }

    return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class,
        "MultiLayerNetwork", inputType);
  }

  /**
   * For the given input shape/type for the network, return a list of activation sizes for each
   * layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
   *
   * @param inputType Input type for the network
   * @return A list of activation types for the network, indexed by layer number

@ -482,11 +529,10 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {

  /**
   * Whether to override the nIn configuration forcibly upon construction. Default value is true.
   *
   * @param overrideNinUponBuild Whether to override the nIn configuration forcibly upon
   *                             construction.
   * @return builder pattern
   */
  public Builder overrideNinUponBuild(boolean overrideNinUponBuild) {

@ -495,8 +541,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * Specify the processors. These are used at each layer for doing things like normalization and
   * shaping of input.
   *
   * @param processor what to use to preProcess the data.

@ -507,6 +552,23 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
    return this;
  }

  public Builder inputPreProcessor(String layer, InputPreProcessor processor) {
    int i = 0;
    for (NeuralNetConfiguration conf : this.confs) {
      if (conf.getLayer().getLayerName().equals(layer)) {
        inputPreProcessors.put(i, processor);
        log.trace("Assigned preProcessor to layer with name {} at index {}", layer, i);
        break;
      }
      i++;
    }
    if (i >= this.confs.size()) {
      log.warn("Could not assign preprocessor to layer with name {} as layer was not found.",
          layer);
    }
    return this;
  }

  public Builder inputPreProcessors(Map<Integer, InputPreProcessor> processors) {
    this.inputPreProcessors = processors;
    return this;
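A brief usage sketch of the new name-based overload added above; 'builder', the layer name "dense-in" and the preprocessor type are all illustrative assumptions, not part of this change.

// Hypothetical: 'builder' is a MultiLayerConfiguration.Builder whose confs already contain
// a layer whose getLayerName() equals "dense-in"; otherwise only the warning above is logged.
builder.inputPreProcessor("dense-in", new CnnToFeedForwardPreProcessor(24, 24, 20));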
@ -531,10 +593,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * This method defines how/if preOutput cache is handled:<br>
   * NONE: cache disabled (default value)<br>
   * HOST: host memory will be used<br>
   * DEVICE: GPU memory will be used (on CPU backends the effect will be the same as for HOST)
   *
   * @param cacheMode cache mode to use
   * @return builder pattern

@ -545,9 +606,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but
   * optionally truncated BPTT can be used for training recurrent neural networks. If using
   * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
   */
  public Builder backpropType(@NonNull BackpropType type) {
    this.backpropType = type;

@ -555,9 +616,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * When doing truncated BPTT: how many steps should we do?<br>
   * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
   * See: <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
   *
   * @param bpttLength length > 0
   */

@ -567,14 +628,14 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * When doing truncated BPTT: how many steps of forward pass should we do before doing
   * (truncated) backprop?<br>
   * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
   * Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter,
   * but may be larger than it in some circumstances (but never smaller)<br>
   * Ideally your training data time series length should be divisible by this.
   * This is the k1 parameter on pg23 of
   * <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
   *
   * @param forwardLength Forward length > 0, >= backwardLength
   */

@ -584,10 +645,10 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * When doing truncated BPTT: how many steps of backward should we do?<br>
   * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
   * This is the k2 parameter on pg23 of
   * <a href="http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf">http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf</a>
   *
   * @param backwardLength <= forwardLength
   */
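To tie the three TBPTT-related settings above together, a hedged configuration sketch follows; the layer sizes and the segment length of 20 are illustrative values only.

// Illustrative recurrent net using truncated BPTT with k1 = k2 = 20 time steps.
MultiLayerConfiguration rnnConf = new NeuralNetConfiguration.Builder()
    .list()
    .layer(new LSTM.Builder().nIn(10).nOut(32).activation(Activation.TANH).build())
    .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .activation(Activation.SOFTMAX).nIn(32).nOut(5).build())
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(20)
    .tBPTTBackwardLength(20)
    .build();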
@ -607,12 +668,12 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * Enabled by default. If enabled, the output layer configuration will be validated, to throw an
   * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.<br>
   * If disabled (false) no output layer validation will be performed.<br>
   * Disabling this validation is not recommended, as the configurations that fail validation
   * usually will not be able to learn correctly. However, the option to disable this validation
   * is provided for advanced users when creating non-standard architectures.
   *
   * @param validate If true: validate output layer configuration. False: don't validate
   */

@ -622,10 +683,11 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * Enabled by default. If enabled, an exception will be thrown when using the (invalid)
   * combination of truncated backpropagation through time (TBPTT) with either a
   * GlobalPoolingLayer or LastTimeStepLayer.<br>
   * It is possible to disable this validation to allow what is almost certainly an invalid
   * configuration to be used, however this is not recommended.
   *
   * @param validate Whether TBPTT validation should be performed
   */

@ -635,7 +697,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
  }

  /**
   * Set the DataType for the network parameters and activations for all layers in the network.
   * Default: Float
   *
   * @param dataType Datatype to use for parameters and activations
   */
  public Builder dataType(@NonNull DataType dataType) {

@ -646,9 +710,12 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {

  public MultiLayerConfiguration build() {
    //Validate BackpropType setting
    if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
        && backpropType != BackpropType.TruncatedBPTT) {
      log.warn("Truncated backpropagation through time lengths have been configured with values "
          + tbpttFwdLength + " and " + tbpttBackLength + " but backprop type is set to "
          + backpropType + ". TBPTT configuration settings will only take effect if backprop type"
          + " is set to BackpropType.TruncatedBPTT");
    }
@ -657,15 +724,18 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
      for (int i = 0; i < confs.size(); i++) {
        Layer l = confs.get(i).getLayer();
        if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
          throw new IllegalStateException(
              "Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
                  + " cannot be used with layer " + i + " of type " + l.getClass().getName()
                  + ": TBPTT is incompatible with this layer type (which is designed "
                  + "to process entire sequences at once, and does not support the type of sequence segments that TBPTT uses).\n"
                  + "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
        }
      }
    }

    if (inputType == null && inputPreProcessors.get(0) == null) {
      //User hasn't set the InputType. Sometimes we can infer it...
      // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in

@ -690,7 +760,6 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
      }
    }

    //Add preprocessors and set nIns, if InputType has been set
    // Builder.inputType field can be set in 1 of 4 ways:
    // 1. User calls setInputType directly

@ -723,17 +792,19 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
            InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType;
            feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
          } else {
            l.setNIn(currentInputType,
                overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user
          }
        } else {
          l.setNIn(currentInputType,
              overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user
        }

        currentInputType = l.getOutputType(i, currentInputType);
      }

@ -758,7 +829,8 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
      //Validate output layer configurations...
      for (NeuralNetConfiguration n : conf.getConfs()) {
        Layer l = n.getLayer();
        OutputLayerUtil.validateOutputLayer(l.getLayerName(), l); //No-op for non output/loss layers
      }
    }
@ -0,0 +1,97 @@
/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */

package org.deeplearning4j.nn.conf.layers.wrapper;

import java.util.Collection;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import net.brutex.ai.dnn.api.LayerConfiguration;
import net.brutex.ai.dnn.api.NeuralNetwork;
import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC)
public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration {

  @NonNull
  @Getter
  private NeuralNetworkConfiguration conf;

  @Override
  public Layer instantiate(NeuralNetConfiguration conf,
      Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
      boolean initializeParams, DataType networkDataType) {
    return null;
  }

  @Override
  public ParamInitializer initializer() {
    return null;
  }

  @Override
  public InputType getOutputType(int layerIndex, InputType inputType) {
    return null;
  }

  @Override
  public void setNIn(InputType inputType, boolean override) {

  }

  @Override
  public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    return null;
  }

  @Override
  public boolean isPretrainParam(String paramName) {
    return false;
  }

  @Override
  public LayerMemoryReport getMemoryReport(InputType inputType) {
    return null;
  }

  /**
   * Create and return an instance of a LayerConfiguration.
   *
   * @param network the "holding" network for the instance
   * @return the new layer instance
   */
  @Override
  public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) {
    return null;
  }
}
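Because the new class relies on Lombok's @Builder, instances would be created through the generated builder; the following one-liner is a hedged sketch, and 'nnConf' is an assumed, pre-built NeuralNetworkConfiguration.

// Hypothetical: Lombok generates BuildingBlockLayer.builder() from the annotation above.
BuildingBlockLayer block = BuildingBlockLayer.builder()
    .conf(nnConf)
    .build();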
@ -101,7 +101,7 @@ import java.util.*;

@Slf4j
public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork {

  //the hidden neural network layers (including output layer)
  protected Layer[] layers;


@ -100,6 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators'
include ':cavis-dnn:cavis-dnn-modelimport'
include ':cavis-dnn:cavis-dnn-nlp'
include ':cavis-dnn:cavis-dnn-nn'
include ':cavis-dnn:cavis-dnn-nn-api'
include ':cavis-dnn:cavis-dnn-nn-parent'
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server'
include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client'

@ -154,3 +155,4 @@ include ':brutex-extended-tests'
include ':cavis-full'