Compare commits

..

No commits in common. "1c3496ad844392afb1411b78234e9720a023d972" and "1749b9e4afa92f11f03889e73f671237e4dc771b" have entirely different histories.

3742 changed files with 83578 additions and 66081 deletions

View File

@ -1,15 +1,8 @@
FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04 FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
RUN apt-get update && \ RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
#Build cmake version from source \ RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
#RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \ tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
# tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \ ./bootstrap && make && make install
# ./bootstrap && make && make install
RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2-linux-x86_64.sh && \
mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \
rm cmake-3.24.2-linux-x86_64.sh
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf

30
.gitignore vendored
View File

@ -36,8 +36,6 @@ pom.xml.versionsBackup
pom.xml.next pom.xml.next
release.properties release.properties
*dependency-reduced-pom.xml *dependency-reduced-pom.xml
**/build/*
.gradle/*
# Specific for Nd4j # Specific for Nd4j
*.md5 *.md5
@ -52,12 +50,12 @@ release.properties
*.dylib *.dylib
.vs/ .vs/
.vscode/ .vscode/
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all* nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/ nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/ nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
doc_sources/ doc_sources/
doc_sources_* doc_sources_*
@ -69,8 +67,8 @@ venv/
venv2/ venv2/
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
# Ignore meld temp files # Ignore meld temp files
*.orig *.orig
@ -84,15 +82,3 @@ bruai4j-native-common/cmake*
*.dll *.dll
/bruai4j-native/bruai4j-native-common/blasbuild/ /bruai4j-native/bruai4j-native-common/blasbuild/
/bruai4j-native/bruai4j-native-common/build/ /bruai4j-native/bruai4j-native-common/build/
/cavis-native/cavis-native-lib/blasbuild/
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF

View File

@ -1,82 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
label 'linux'
}
stages {
stage('prep-build-environment-linux-cpu') {
steps {
checkout scm
sh 'gcc --version'
sh 'cmake --version'
sh 'sh ./gradlew --version'
}
}
stage('build-linux-cpu') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
/*stage('test-linux-cpu') {
environment {
MAVEN = credentials('Internal Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}*/
stage('publish-linux-cpu') {
environment {
MAVEN = credentials('Internal Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
}
}

View File

@ -1,62 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
/* dockerfile {
filename 'Dockerfile'
dir '.docker'
label 'linux && cuda'
//additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all' --needed for test only, you can build without GPU
}
*/
label 'linux && cuda'
}
stages {
stage('prep-build-environment-linux-cuda') {
steps {
checkout scm
//sh 'nvidia-smi'
sh 'nvcc --version'
sh 'gcc --version'
sh 'cmake --version'
sh 'sh ./gradlew --version'
}
}
stage('build-linux-cuda') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
}
}

View File

@ -1,66 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
dockerfile {
filename 'Dockerfile'
dir '.docker'
label 'linux && docker && cuda'
//additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all' --needed for test only, you can build without GPU
}
}
stages {
stage("Build all chip") {
parallel {
stage('prep-build-environment-linux-cuda') {
steps {
checkout scm
//sh 'nvidia-smi'
sh 'nvcc --version'
sh 'gcc --version'
sh 'cmake --version'
sh 'sh ./gradlew --version'
}
}
stage('build-linux-cuda') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
}
}
}
}

View File

@ -1,88 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
dockerfile {
filename 'Dockerfile'
dir '.docker'
label 'linux && docker'
//additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all'
}
}
stages {
stage('prep-build-environment-linux-cpu') {
steps {
checkout scm
sh 'gcc --version'
sh 'cmake --version'
sh 'sh ./gradlew --version'
}
}
stage('build-linux-cpu') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
stage('test-linux-cpu') {
environment {
MAVEN = credentials('Internal Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
//sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \
// -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
// -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
stage('publish-linux-cpu') {
environment {
MAVEN = credentials('Internal Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
}
}

View File

@ -1,49 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
dockerfile {
filename 'Dockerfile'
dir '.docker'
label 'linux && docker'
//additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all'
}
}
stages {
stage('publish-linux-cpu') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew publish -x test -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
}
}
}
}

View File

@ -20,13 +20,14 @@
*/ */
pipeline { pipeline {
agent { agent {
dockerfile { dockerfile {
filename 'Dockerfile' filename 'Dockerfile'
dir '.docker' dir '.docker'
label 'linux && docker && cuda' label 'linuxdocker'
//additionalBuildArgs '--build-arg version=1.0.2' //additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all' --needed for test only, you can build without GPU args '--gpus all'
} }
} }
@ -34,7 +35,7 @@ pipeline {
stage('prep-build-environment-linux-cuda') { stage('prep-build-environment-linux-cuda') {
steps { steps {
checkout scm checkout scm
//sh 'nvidia-smi' sh 'nvidia-smi'
sh 'nvcc --version' sh 'nvcc --version'
sh 'gcc --version' sh 'gcc --version'
sh 'cmake --version' sh 'cmake --version'
@ -43,33 +44,19 @@ pipeline {
} }
stage('build-linux-cuda') { stage('build-linux-cuda') {
environment { environment {
MAVEN = credentials('Internal_Archiva') MAVEN = credentials('Internal Archiva')
OSSRH = credentials('OSSRH') OSSRH = credentials('OSSRH')
} }
steps { steps {
withGradle { withGradle {
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \ sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cuda \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
} }
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
} }
}
stage('test-linux-cuda') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
} }
} }
} }

View File

@ -1,58 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
pipeline {
agent {
dockerfile {
filename 'Dockerfile'
dir '.docker'
label 'WSL-docker'
//additionalBuildArgs '--build-arg version=1.0.2'
//args '--gpus all'
}
}
stages {
stage('prep-build-environment-linux-cpu') {
steps {
checkout scm
sh 'gcc --version'
sh 'cmake --version'
sh 'sh ./gradlew --version'
}
}
stage('build-linux-cpu') {
environment {
MAVEN = credentials('Internal_Archiva')
OSSRH = credentials('OSSRH')
}
steps {
withGradle {
sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cpu \
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
}
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
}
}
}
}

View File

@ -48,12 +48,12 @@ Deeplearning4J offers a very high level API for defining even complex neural net
you how LeNet, a convolutional neural network, is defined in DL4J. you how LeNet, a convolutional neural network, is defined in DL4J.
```java ```java
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed) .seed(seed)
.l2(0.0005) .l2(0.0005)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Adam(1e-3)) .updater(new Adam(1e-3))
.list()
.layer(new ConvolutionLayer.Builder(5, 5) .layer(new ConvolutionLayer.Builder(5, 5)
.stride(1,1) .stride(1,1)
.nOut(20) .nOut(20)
@ -78,7 +78,7 @@ NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.nOut(outputNum) .nOut(outputNum)
.activation(Activation.SOFTMAX) .activation(Activation.SOFTMAX)
.build()) .build())
.inputType(InputType.convolutionalFlat(28,28,1)) .setInputType(InputType.convolutionalFlat(28,28,1))
.build(); .build();
``` ```

24
arbiter/.travis.yml Normal file
View File

@ -0,0 +1,24 @@
branches:
only:
- master
notifications:
email: false
dist: trusty
sudo: false
cache:
directories:
- $HOME/.m2
language: java
jdk:
- openjdk8
matrix:
include:
- os: linux
env: OS=linux-x86_64 SCALA=2.10
install: true
script: bash ./ci/build-linux-x86_64.sh
- os: linux
env: OS=linux-x86_64 SCALA=2.11
install: true
script: bash ./ci/build-linux-x86_64.sh

45
arbiter/README.md Normal file
View File

@ -0,0 +1,45 @@
# Arbiter
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
## Modules
Arbiter contains the following modules:
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
## Hyperparameter Optimization Functionality
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
- Grid search
- Random search
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate models configurations of some type ```C```. The following implementations are defined in arbiter-core:
- ```RandomSearchCandidateGenerator```
- ```GridSearchCandidateGenerator```
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
- A key concept here is that they score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
- Note that ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
### Optimization of DeepLearning4J Models
(This section: forthcoming)

View File

@ -0,0 +1,97 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>arbiter</artifactId>
<groupId>net.brutex.ai</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>arbiter-core</artifactId>
<packaging>jar</packaging>
<name>arbiter-core</name>
<dependencies>
<dependency>
<groupId>net.brutex.ai</groupId>
<artifactId>nd4j-api</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.jre.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons.lang.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>${commons.math.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>${jodatime.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>net.brutex.ai</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-joda</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>net.brutex.ai</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
<classifier>windows-x86_64</classifier>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,91 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<assembly>
<id>bin</id>
<!-- START SNIPPET: formats -->
<formats>
<format>tar.gz</format>
<!--
<format>tar.bz2</format>
<format>zip</format>
-->
</formats>
<!-- END SNIPPET: formats -->
<dependencySets>
<dependencySet>
<outputDirectory>lib</outputDirectory>
<includes>
<include>*:jar:*</include>
</includes>
<excludes>
<exclude>*:sources</exclude>
</excludes>
</dependencySet>
</dependencySets>
<!-- START SNIPPET: fileSets -->
<fileSets>
<fileSet>
<includes>
<include>readme.txt</include>
</includes>
</fileSet>
<fileSet>
<directory>src/main/resources/bin/</directory>
<outputDirectory>bin</outputDirectory>
<includes>
<include>arbiter</include>
</includes>
<lineEnding>unix</lineEnding>
<fileMode>0755</fileMode>
</fileSet>
<fileSet>
<directory>examples</directory>
<outputDirectory>examples</outputDirectory>
<!--
<lineEnding>unix</lineEnding>
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
-->
</fileSet>
<!--
<fileSet>
<directory>src/bin</directory>
<outputDirectory>bin</outputDirectory>
<includes>
<include>hello</include>
</includes>
<lineEnding>unix</lineEnding>
<fileMode>0755</fileMode>
</fileSet>
-->
<fileSet>
<directory>target</directory>
<outputDirectory>./</outputDirectory>
<includes>
<include>*.jar</include>
</includes>
</fileSet>
</fileSets>
<!-- END SNIPPET: fileSets -->
</assembly>

View File

@ -0,0 +1,74 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Created by Alex on 23/07/2017.
*/
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
Map<String, ParameterSpace> m = new LinkedHashMap<>();
//Need to manually build and walk the class heirarchy...
Class<?> currClass = this.getClass();
List<Class<?>> classHeirarchy = new ArrayList<>();
while (currClass != Object.class) {
classHeirarchy.add(currClass);
currClass = currClass.getSuperclass();
}
for (int i = classHeirarchy.size() - 1; i >= 0; i--) {
//Use reflection here to avoid a mass of boilerplate code...
Field[] allFields = classHeirarchy.get(i).getDeclaredFields();
for (Field f : allFields) {
String name = f.getName();
Class<?> fieldClass = f.getType();
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
if (!isParamSpacefield) {
continue;
}
f.setAccessible(true);
ParameterSpace<?> p;
try {
p = (ParameterSpace<?>) f.get(this);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
if (p != null) {
m.put(name, p);
}
}
}
return m;
}
}

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
import org.nd4j.common.function.Supplier;
import java.io.Serializable;
import java.util.Map;
/**
* Candidate: a proposed hyperparameter configuration.
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
*/
@Data
@AllArgsConstructor
public class Candidate<C> implements Serializable {
private Supplier<C> supplier;
private int index;
private double[] flatParameters;
private Map<String, Object> dataParameters;
private Exception exception;
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
}
public Candidate(C value, int index, double[] flatParameters) {
this(new SerializedSupplier<C>(value), index, flatParameters);
}
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
this(value, index, flatParameters, null, null);
}
public C getValue(){
return supplier.get();
}
}

View File

@ -0,0 +1,68 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
/**
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
* This abstraction allows for different ways of generating the next configuration to test; for example,
* random search, grid search, Bayesian optimization methods, etc.
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface CandidateGenerator {
/**
* Is this candidate generator able to generate more candidates? This will always return true in some
* cases, but some search strategies have a limit (grid search, for example)
*/
boolean hasMoreCandidates();
/**
* Generate a candidate hyperparameter configuration
*/
Candidate getCandidate();
/**
* Report results for the candidate generator.
*
* @param result The results to report
*/
void reportResults(OptimizationResult result);
/**
* @return Get the parameter space for this candidate generator
*/
ParameterSpace<?> getParameterSpace();
/**
* @param rngSeed Set the random number generator seed for the candidate generator
*/
void setRngSeed(long rngSeed);
/**
* @return The type (class) of the generated candidates
*/
Class<?> getCandidateType();
}

View File

@ -0,0 +1,60 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import lombok.Data;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
/**
* An optimization result represents the results of an optimization run, including the canditate configuration, the
* trained model, the score for that model, and index of the model
*
* @author Alex Black
*/
@Data
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonIgnoreProperties({"resultReference"})
public class OptimizationResult implements Serializable {
@JsonProperty
private Candidate candidate;
@JsonProperty
private Double score;
@JsonProperty
private int index;
@JsonProperty
private Object modelSpecificResults;
@JsonProperty
private CandidateInfo candidateInfo;
private ResultReference resultReference;
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
CandidateInfo candidateInfo, ResultReference resultReference) {
this.candidate = candidate;
this.score = score;
this.index = index;
this.modelSpecificResults = modelSpecificResults;
this.candidateInfo = candidateInfo;
this.resultReference = resultReference;
}
}

View File

@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.util.List;
import java.util.Map;
/**
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
* multiple nested ParameterSpaces
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ParameterSpace<P> {
/**
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
* mapping function (such as the prior probability distribution)
*
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
*/
P getValue(double[] parameterValues);
/**
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
*
* @return Number of hyperparameters to be optimized
*/
int numParameters();
/**
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
* nested parameter spaces
*/
List<ParameterSpace> collectLeaves();
/**
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
* nested parameter spaces. The map should be empty for leaf parameter spaces
*
* @return A map of nested parameter spaces
*/
Map<String, ParameterSpace> getNestedSpaces();
/**
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
*/
@JsonIgnore
boolean isLeaf();
/**
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
*
* @param indices Indices to set. Length should equal {@link #numParameters()}
*/
void setIndices(int... indices);
}

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.Callable;
/**
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
* that can be executed as a Callable
*
* @author Alex Black
*/
public interface TaskCreator {
/**
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
*
* @param candidate Candidate (model) configuration to be trained
* @param dataProvider DataProvider, for the data
* @param scoreFunction Score function to be used to evaluate the model
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
@Deprecated
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
List<StatusListener> statusListeners, IOptimizationRunner runner);
/**
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
*
* @param candidate Candidate (model) configuration to be trained
* @param dataSource Data source
* @param dataSourceProperties Properties (may be null) for the data source
* @param scoreFunction Score function to be used to evaluate the model
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
}

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api;
import java.util.HashMap;
import java.util.Map;
public class TaskCreatorProvider {
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
try {
if(c == null){
return null;
}
return c.newInstance();
} catch (Exception e){
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
}
}
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
Class<? extends TaskCreator> creatorClass){
map.put(spaceClass, creatorClass);
}
}

View File

@ -0,0 +1,82 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.adapter;
import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
*
* @param <F> Type to convert from
* @param <T> Type to convert to
* @author Alex Black
*/
@AllArgsConstructor
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
protected abstract T convertValue(F from);
protected abstract ParameterSpace<F> underlying();
protected abstract String underlyingName();
@Override
public T getValue(double[] parameterValues) {
return convertValue(underlying().getValue(parameterValues));
}
@Override
public int numParameters() {
return underlying().numParameters();
}
@Override
public List<ParameterSpace> collectLeaves() {
ParameterSpace p = underlying();
if(p.isLeaf()){
return Collections.singletonList(p);
}
return underlying().collectLeaves();
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
}
@Override
public boolean isLeaf() {
return false; //Underlying may be a leaf, however
}
@Override
public void setIndices(int... indices) {
underlying().setIndices(indices);
}
@Override
public String toString() {
return underlying().toString();
}
}

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.data;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
import java.util.Map;
/**
* DataProvider interface abstracts out the providing of data
* @deprecated Use {@link DataSource}
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Deprecated
public interface DataProvider extends Serializable {
/**
* Get training data given some parameters for the data.
* Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
Object trainData(Map<String, Object> dataParameters);
/**
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
Object testData(Map<String, Object> dataParameters);
Class<?> getDataType();
}

View File

@ -0,0 +1,89 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.data;
import lombok.Data;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import java.util.Map;
/**
* This is a {@link DataProvider} for
* an {@link DataSetIteratorFactory} which
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
* for use with arbiter.
*
* This {@link DataProvider} is mainly meant for use for command line driven
* applications.
*
* @author Adam Gibson
*/
@Data
public class DataSetIteratorFactoryProvider implements DataProvider {
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
/**
* Get training data given some parameters for the data.
* Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
@Override
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
return create(dataParameters);
}
/**
* Get training data given some parameters for the data. Data parameters map
* is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
@Override
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
return create(dataParameters);
}
@Override
public Class<?> getDataType() {
return DataSetIteratorFactory.class;
}
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
if (dataParameters == null)
throw new IllegalArgumentException(
"Data parameters is null. Please specify a class name to create a dataset iterator.");
if (!dataParameters.containsKey(FACTORY_KEY))
throw new IllegalArgumentException(
"No data set iterator factory class found. Please specify a class name with key "
+ FACTORY_KEY);
String value = dataParameters.get(FACTORY_KEY).toString();
try {
Class<? extends DataSetIteratorFactory> clazz =
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
return clazz.newInstance();
} catch (Exception e) {
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
}
}
}

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.data;
import java.io.Serializable;
import java.util.Properties;
/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument contsructor
*
* @author Alex Black
*/
public interface DataSource extends Serializable {
/**
* Configure the current data source with the specified properties
* Note: These properties are fixed for the training instance, and are optionally provided by the user
* at the configuration stage.
* The properties could be anything - and are usually specific to each DataSource implementation.
* For example, values such as batch size could be set using these properties
* @param properties Properties to apply to the data source instance
*/
void configure(Properties properties);
/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object trainData();
/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object testData();
/**
* The type of data returned by {@link #trainData()} and {@link #testData()}.
* Usually DataSetIterator or MultiDataSetIterator
* @return Class of the objects returned by trainData and testData
*/
Class<?> getDataType();
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.evaluation;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import java.io.Serializable;
import java.util.List;
/**
* ModelEvaluator: Used to conduct additional evaluation.
* For example, this may be classification performance on a test set or similar
*/
public interface ModelEvaluator extends Serializable {
Object evaluateModel(Object model, DataProvider dataProvider);
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
/**
* @return The datatypes supported by this class
*/
List<Class<?>> getSupportedDataTypes();
}

View File

@ -0,0 +1,63 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.saving;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
/**
* A simple class to store optimization results in-memory.
* Not recommended for large (or a large number of) models.
*/
@NoArgsConstructor
public class InMemoryResultSaver implements ResultSaver {
@Override
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
return new InMemoryResult(result, modelResult);
}
@Override
public List<Class<?>> getSupportedCandidateTypes() {
return Collections.<Class<?>>singletonList(Object.class);
}
@Override
public List<Class<?>> getSupportedModelTypes() {
return Collections.<Class<?>>singletonList(Object.class);
}
@AllArgsConstructor
private static class InMemoryResult implements ResultReference {
private OptimizationResult result;
private Object modelResult;
@Override
public OptimizationResult getResult() throws IOException {
return result;
}
@Override
public Object getResultModel() throws IOException {
return modelResult;
}
}
}

View File

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.saving;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.IOException;
/**
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
* parameters each)
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
* and we can easily load it back into memory (if/when required) using the getResult() method
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ResultReference {
OptimizationResult getResult() throws IOException;
Object getResultModel() throws IOException;
}

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.saving;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.IOException;
import java.util.List;
/**
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
* regardless of where/how they are saved.
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ResultSaver {
/**
* Save the model (including configuration and any additional evaluation/results)
*
* @param result Optimization result for the model to save
* @param modelResult Model result to save
* @return ResultReference, such that the result can be loaded back into memory
* @throws IOException If IO error occurs during model saving
*/
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
/**
* @return The candidate types supported by this class
*/
List<Class<?>> getSupportedCandidateTypes();
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
}

View File

@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.score;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Properties;
/**
* ScoreFunction defines the objective of hyperparameter optimization.
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
* in the configuration.
*
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ScoreFunction extends Serializable {
/**
* Calculate and return the score, for the given model and data provider
*
* @param model Model to score
* @param dataProvider Data provider - data to use
* @param dataParameters Parameters for data
* @return Calculated score
*/
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
/**
* Calculate and return the score, for the given model and data provider
*
* @param model Model to score
* @param dataSource Data source
* @param dataSourceProperties data source properties
* @return Calculated score
*/
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
/**
* Should this score function be minimized or maximized?
*
* @return true if score should be minimized, false if score should be maximized
*/
boolean minimize();
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
/**
* @return The data types supported by this class
*/
List<Class<?>> getSupportedDataTypes();
}

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.termination;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import com.fasterxml.jackson.annotation.JsonProperty;
/**
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
* Note that this is counted as number of completed candidates, plus number of failed candidates.
*/
@AllArgsConstructor
@NoArgsConstructor
@Data
public class MaxCandidatesCondition implements TerminationCondition {
@JsonProperty
private int maxCandidates;
@Override
public void initialize(IOptimizationRunner optimizationRunner) {
//No op
}
@Override
public boolean terminate(IOptimizationRunner optimizationRunner) {
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
}
@Override
public String toString() {
return "MaxCandidatesCondition(" + maxCandidates + ")";
}
}

View File

@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.termination;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.concurrent.TimeUnit;
/**
* Terminate hyperparameter optimization after
* a fixed amount of time has passed
* @author Alex Black
*/
@NoArgsConstructor
@Data
public class MaxTimeCondition implements TerminationCondition {
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
private long duration;
private TimeUnit timeUnit;
private long startTime;
private long endTime;
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
this.duration = duration;
this.timeUnit = timeUnit;
this.startTime = startTime;
this.endTime = endTime;
}
/**
* @param duration Duration of time
* @param timeUnit Unit that the duration is specified in
*/
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
this.duration = duration;
this.timeUnit = timeUnit;
}
@Override
public void initialize(IOptimizationRunner optimizationRunner) {
startTime = System.currentTimeMillis();
this.endTime = startTime + timeUnit.toMillis(duration);
}
@Override
public boolean terminate(IOptimizationRunner optimizationRunner) {
return System.currentTimeMillis() >= endTime;
}
@Override
public String toString() {
if (startTime > 0) {
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
+ "\",end=\"" + formatter.print(endTime) + "\")";
} else {
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
}
}
}

View File

@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.api.termination;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
/**
* Global termination condition for conducting hyperparameter optimization.
* Termination conditions are used to determine if/when the optimization should stop.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonInclude(JsonInclude.Include.NON_NULL)
public interface TerminationCondition {
/**
* Initialize the termination condition (such as starting timers, etc).
*/
void initialize(IOptimizationRunner optimizationRunner);
/**
* Determine whether optimization should be terminated
*
* @param optimizationRunner Optimization runner
* @return true if learning should be terminated, false otherwise
*/
boolean terminate(IOptimizationRunner optimizationRunner);
}

View File

@ -0,0 +1,226 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.config;
import lombok.*;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
/**
* OptimizationConfiguration ties together all of the various
* components (such as data, score functions, result saving etc)
* required to execute hyperparameter optimization.
*
* @author Alex Black
*/
@Data
@NoArgsConstructor
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class OptimizationConfiguration {
@JsonSerialize
private DataProvider dataProvider;
@JsonSerialize
private Class<? extends DataSource> dataSource;
@JsonSerialize
private Properties dataSourceProperties;
@JsonSerialize
private CandidateGenerator candidateGenerator;
@JsonSerialize
private ResultSaver resultSaver;
@JsonSerialize
private ScoreFunction scoreFunction;
@JsonSerialize
private List<TerminationCondition> terminationConditions;
@JsonSerialize
private Long rngSeed;
@Getter
@Setter
private long executionStartTime;
private OptimizationConfiguration(Builder builder) {
this.dataProvider = builder.dataProvider;
this.dataSource = builder.dataSource;
this.dataSourceProperties = builder.dataSourceProperties;
this.candidateGenerator = builder.candidateGenerator;
this.resultSaver = builder.resultSaver;
this.scoreFunction = builder.scoreFunction;
this.terminationConditions = builder.terminationConditions;
this.rngSeed = builder.rngSeed;
if (rngSeed != null)
candidateGenerator.setRngSeed(rngSeed);
//Validate the configuration: data types, score types, etc
//TODO
//Validate that the dataSource has a no-arg constructor
if (dataSource != null) {
try {
dataSource.getConstructor();
} catch (NoSuchMethodException e) {
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
}
}
}
public static class Builder {
private DataProvider dataProvider;
private Class<? extends DataSource> dataSource;
private Properties dataSourceProperties;
private CandidateGenerator candidateGenerator;
private ResultSaver resultSaver;
private ScoreFunction scoreFunction;
private List<TerminationCondition> terminationConditions;
private Long rngSeed;
/**
* @deprecated Use {@link #dataSource(Class, Properties)}
*/
@Deprecated
public Builder dataProvider(DataProvider dataProvider) {
this.dataProvider = dataProvider;
return this;
}
/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument contsructor
*
* @param dataSource Class for the data source
* @param dataSourceProperties May be null. Properties for configuring the data source
*/
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
this.dataSource = dataSource;
this.dataSourceProperties = dataSourceProperties;
return this;
}
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
this.candidateGenerator = candidateGenerator;
return this;
}
public Builder modelSaver(ResultSaver resultSaver) {
this.resultSaver = resultSaver;
return this;
}
public Builder scoreFunction(ScoreFunction scoreFunction) {
this.scoreFunction = scoreFunction;
return this;
}
/**
* Termination conditions to use
*
* @param conditions
* @return
*/
public Builder terminationConditions(TerminationCondition... conditions) {
terminationConditions = Arrays.asList(conditions);
return this;
}
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
this.terminationConditions = terminationConditions;
return this;
}
public Builder rngSeed(long rngSeed) {
this.rngSeed = rngSeed;
return this;
}
public OptimizationConfiguration build() {
return new OptimizationConfiguration(this);
}
}
/**
* Create an optimization configuration from the json
*
* @param json the json to create the config from
* For type definitions
* @see OptimizationConfiguration
*/
public static OptimizationConfiguration fromYaml(String json) {
try {
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Create an optimization configuration from the json
*
* @param json the json to create the config from
* @see OptimizationConfiguration
*/
public static OptimizationConfiguration fromJson(String json) {
try {
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Return a json configuration of this optimization configuration
*
* @return
*/
public String toJson() {
try {
return JsonMapper.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Return a yaml configuration of this optimization configuration
*
* @return
*/
public String toYaml() {
try {
return JsonMapper.getYamlMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,96 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
/**
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
*/
public class DegenerateIntegerDistribution implements IntegerDistribution {
private int value;
public DegenerateIntegerDistribution(int value) {
this.value = value;
}
@Override
public double probability(int x) {
return (x == value ? 1.0 : 0.0);
}
@Override
public double cumulativeProbability(int x) {
return (x >= value ? 1.0 : 0.0);
}
@Override
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
}
@Override
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
throw new UnsupportedOperationException();
}
@Override
public double getNumericalMean() {
return value;
}
@Override
public double getNumericalVariance() {
return 0;
}
@Override
public int getSupportLowerBound() {
return value;
}
@Override
public int getSupportUpperBound() {
return value;
}
@Override
public boolean isSupportConnected() {
return true;
}
@Override
public void reseedRandomGenerator(long seed) {
//no op
}
@Override
public int sample() {
return value;
}
@Override
public int[] sample(int sampleSize) {
int[] out = new int[sampleSize];
for (int i = 0; i < out.length; i++)
out[i] = value;
return out;
}
}

View File

@ -0,0 +1,149 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.*;
/**
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
* don't implement serializable etc.
* Which makes unit testing etc quite difficult.
*
* @author Alex Black
*/
public class DistributionUtils {
private DistributionUtils() {}
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
if (a.getClass() != b.getClass())
return false;
Class<?> c = a.getClass();
if (c == BetaDistribution.class) {
BetaDistribution ba = (BetaDistribution) a;
BetaDistribution bb = (BetaDistribution) b;
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
} else if (c == CauchyDistribution.class) {
CauchyDistribution ca = (CauchyDistribution) a;
CauchyDistribution cb = (CauchyDistribution) b;
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
} else if (c == ChiSquaredDistribution.class) {
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
} else if (c == ExponentialDistribution.class) {
ExponentialDistribution ea = (ExponentialDistribution) a;
ExponentialDistribution eb = (ExponentialDistribution) b;
return ea.getMean() == eb.getMean();
} else if (c == FDistribution.class) {
FDistribution fa = (FDistribution) a;
FDistribution fb = (FDistribution) b;
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
} else if (c == GammaDistribution.class) {
GammaDistribution ga = (GammaDistribution) a;
GammaDistribution gb = (GammaDistribution) b;
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
} else if (c == LevyDistribution.class) {
LevyDistribution la = (LevyDistribution) a;
LevyDistribution lb = (LevyDistribution) b;
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
} else if (c == LogNormalDistribution.class) {
LogNormalDistribution la = (LogNormalDistribution) a;
LogNormalDistribution lb = (LogNormalDistribution) b;
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
} else if (c == NormalDistribution.class) {
NormalDistribution na = (NormalDistribution) a;
NormalDistribution nb = (NormalDistribution) b;
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
} else if (c == ParetoDistribution.class) {
ParetoDistribution pa = (ParetoDistribution) a;
ParetoDistribution pb = (ParetoDistribution) b;
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
} else if (c == TDistribution.class) {
TDistribution ta = (TDistribution) a;
TDistribution tb = (TDistribution) b;
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
} else if (c == TriangularDistribution.class) {
TriangularDistribution ta = (TriangularDistribution) a;
TriangularDistribution tb = (TriangularDistribution) b;
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
} else if (c == UniformRealDistribution.class) {
UniformRealDistribution ua = (UniformRealDistribution) a;
UniformRealDistribution ub = (UniformRealDistribution) b;
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
} else if (c == WeibullDistribution.class) {
WeibullDistribution wa = (WeibullDistribution) a;
WeibullDistribution wb = (WeibullDistribution) b;
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
} else if (c == LogUniformDistribution.class ){
LogUniformDistribution lu_a = (LogUniformDistribution)a;
LogUniformDistribution lu_b = (LogUniformDistribution)b;
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
} else {
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
}
}
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
if (a.getClass() != b.getClass())
return false;
Class<?> c = a.getClass();
if (c == BinomialDistribution.class) {
BinomialDistribution ba = (BinomialDistribution) a;
BinomialDistribution bb = (BinomialDistribution) b;
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
} else if (c == GeometricDistribution.class) {
GeometricDistribution ga = (GeometricDistribution) a;
GeometricDistribution gb = (GeometricDistribution) b;
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
} else if (c == HypergeometricDistribution.class) {
HypergeometricDistribution ha = (HypergeometricDistribution) a;
HypergeometricDistribution hb = (HypergeometricDistribution) b;
return ha.getPopulationSize() == hb.getPopulationSize()
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
&& ha.getSampleSize() == hb.getSampleSize();
} else if (c == PascalDistribution.class) {
PascalDistribution pa = (PascalDistribution) a;
PascalDistribution pb = (PascalDistribution) b;
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
} else if (c == PoissonDistribution.class) {
PoissonDistribution pa = (PoissonDistribution) a;
PoissonDistribution pb = (PoissonDistribution) b;
return pa.getMean() == pb.getMean();
} else if (c == UniformIntegerDistribution.class) {
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
} else if (c == ZipfDistribution.class) {
ZipfDistribution za = (ZipfDistribution) a;
ZipfDistribution zb = (ZipfDistribution) b;
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
} else {
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
}
}
}

View File

@ -0,0 +1,155 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.distribution;
import com.google.common.base.Preconditions;
import lombok.Getter;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import java.util.Random;
/**
* Log uniform distribution, with support in range [min, max] for min &gt; 0
*
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
*
* @author Alex Black
*/
public class LogUniformDistribution implements RealDistribution {
@Getter private final double min;
@Getter private final double max;
private final double logMin;
private final double logMax;
private transient Random rng = new Random();
/**
*
* @param min Minimum value
* @param max Maximum value
*/
public LogUniformDistribution(double min, double max) {
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
+ min + "," + max + ")");
this.min = min;
this.max = max;
this.logMin = Math.log(min);
this.logMax = Math.log(max);
}
@Override
public double probability(double x) {
if(x < min || x > max){
return 0;
}
return 1.0 / (x * (logMax - logMin));
}
@Override
public double density(double x) {
return probability(x);
}
@Override
public double cumulativeProbability(double x) {
if(x <= min){
return 0.0;
} else if(x >= max){
return 1.0;
}
return (Math.log(x)-logMin)/(logMax-logMin);
}
@Override
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
return cumulativeProbability(x1) - cumulativeProbability(x0);
}
@Override
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
return Math.exp(p * (logMax-logMin) + logMin);
}
@Override
public double getNumericalMean() {
return (max-min)/(logMax-logMin);
}
@Override
public double getNumericalVariance() {
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
return d1 / (2*Math.pow(logMax-logMin, 2.0));
}
@Override
public double getSupportLowerBound() {
return min;
}
@Override
public double getSupportUpperBound() {
return max;
}
@Override
public boolean isSupportLowerBoundInclusive() {
return true;
}
@Override
public boolean isSupportUpperBoundInclusive() {
return true;
}
@Override
public boolean isSupportConnected() {
return true;
}
@Override
public void reseedRandomGenerator(long seed) {
rng.setSeed(seed);
}
@Override
public double sample() {
return inverseCumulativeProbability(rng.nextDouble());
}
@Override
public double[] sample(int sampleSize) {
double[] d = new double[sampleSize];
for( int i=0; i<sampleSize; i++ ){
d[i] = sample();
}
return d;
}
@Override
public String toString(){
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
}
}

View File

@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
* are built.
*
* @param <T> Type of candidates to generate
*/
@Data
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
protected ParameterSpace<T> parameterSpace;
protected AtomicInteger candidateCounter = new AtomicInteger(0);
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
protected Map<String, Object> dataParameters;
protected boolean initDone = false;
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
boolean initDone) {
this.parameterSpace = parameterSpace;
this.dataParameters = dataParameters;
this.initDone = initDone;
}
protected void initialize() {
if(!initDone) {
//First: collect leaf parameter spaces objects and remove duplicates
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
//Second: assign each a number
int i = 0;
for (ParameterSpace ps : noDuplicatesList) {
int np = ps.numParameters();
if (np == 1) {
ps.setIndices(i++);
} else {
int[] values = new int[np];
for (int j = 0; j < np; j++)
values[j] = i++;
ps.setIndices(values);
}
}
initDone = true;
}
}
@Override
public ParameterSpace<T> getParameterSpace() {
return parameterSpace;
}
@Override
public void reportResults(OptimizationResult result) {
//No op
}
@Override
public void setRngSeed(long rngSeed) {
rng.setSeed(rngSeed);
}
}

View File

@ -0,0 +1,187 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
import java.util.Map;
/**
* Uses a genetic algorithm to generate candidates.
*
* @author Alexandre Boulanger
*/
@Slf4j
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
@Getter
protected final PopulationModel populationModel;
protected final ChromosomeFactory chromosomeFactory;
protected final SelectionOperator selectionOperator;
protected boolean hasMoreCandidates = true;
public static class Builder {
protected final ParameterSpace<?> parameterSpace;
protected Map<String, Object> dataParameters;
protected boolean initDone;
protected boolean minimizeScore;
protected PopulationModel populationModel;
protected ChromosomeFactory chromosomeFactory;
protected SelectionOperator selectionOperator;
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
*/
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
this.parameterSpace = parameterSpace;
this.minimizeScore = scoreFunction.minimize();
}
/**
* @param populationModel The PopulationModel instance to use.
*/
public Builder populationModel(PopulationModel populationModel) {
this.populationModel = populationModel;
return this;
}
/**
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
*/
public Builder selectionOperator(SelectionOperator selectionOperator) {
this.selectionOperator = selectionOperator;
return this;
}
public Builder dataParameters(Map<String, Object> dataParameters) {
this.dataParameters = dataParameters;
return this;
}
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
this.initDone = initDone;
return this;
}
/**
* @param chromosomeFactory The ChromosomeFactory to use
*/
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
this.chromosomeFactory = chromosomeFactory;
return this;
}
public GeneticSearchCandidateGenerator build() {
if (populationModel == null) {
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
.build();
}
if (chromosomeFactory == null) {
chromosomeFactory = new ChromosomeFactory();
}
if (selectionOperator == null) {
selectionOperator = new GeneticSelectionOperator.Builder().build();
}
return new GeneticSearchCandidateGenerator(this);
}
}
private GeneticSearchCandidateGenerator(Builder builder) {
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
initialize();
chromosomeFactory = builder.chromosomeFactory;
populationModel = builder.populationModel;
selectionOperator = builder.selectionOperator;
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
populationModel.initializeInstance(builder.minimizeScore);
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
}
@Override
public boolean hasMoreCandidates() {
return hasMoreCandidates;
}
@Override
public Candidate getCandidate() {
double[] values = null;
Object value = null;
Exception e = null;
try {
values = selectionOperator.buildNextGenes();
value = parameterSpace.getValue(values);
} catch (GeneticGenerationException e2) {
log.warn("Error generating candidate", e2);
e = e2;
hasMoreCandidates = false;
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
@Override
public String toString() {
return "GeneticSearchCandidateGenerator";
}
@Override
public void reportResults(OptimizationResult result) {
if (result.getScore() == null) {
return;
}
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
result.getScore());
populationModel.add(newChromosome);
}
}

View File

@ -0,0 +1,232 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.random.RandomAdaptor;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
* Note that:<br>
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
* that hyperparameter<br>
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
*
* @author Alex Black
*/
@Slf4j
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
/**
* In what order should candidates be generated?<br>
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
* will be changed least rapidly.<br>
* <b>RandomOrder</b>: generate candidates in a random order<br>
* In both cases, the same candidates will be generated; only the order of generation is different
*/
public enum Mode {
Sequential, RandomOrder
}
private final int discretizationCount;
private final Mode mode;
private int[] numValuesPerParam;
@Getter
private int totalNumCandidates;
private Queue<Integer> order;
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
* do [0.0, 0.5, 1.0]. Note that if all values
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
* in which candidates should be generated.
*/
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
@JsonProperty("initDone") boolean initDone) {
super(parameterSpace, dataParameters, initDone);
this.discretizationCount = discretizationCount;
this.mode = mode;
initialize();
}
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
* do [0.0, 0.5, 1.0]. Note that if all values
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
* in which candidates should be generated.
*/
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
Map<String, Object> dataParameters){
this(parameterSpace, discretizationCount, mode, dataParameters, false);
}
@Override
protected void initialize() {
super.initialize();
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
int nParams = leaves.size();
//Work out for each parameter: is it continuous or discrete?
// for grid search: discrete values are grid-searchable as-is
// continuous values: discretize using 'discretizationCount' bins
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
numValuesPerParam = new int[nParams];
long searchSize = 1;
for (int i = 0; i < nParams; i++) {
ParameterSpace ps = leaves.get(i);
if (ps instanceof DiscreteParameterSpace) {
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
numValuesPerParam[i] = dps.numValues();
} else if (ps instanceof IntegerParameterSpace) {
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
int min = ips.getMin();
int max = ips.getMax();
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
} else if (ps instanceof FixedValue){
numValuesPerParam[i] = 1;
} else {
numValuesPerParam[i] = discretizationCount;
}
searchSize *= numValuesPerParam[i];
}
if (searchSize >= Integer.MAX_VALUE)
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
order = new ConcurrentLinkedQueue<>();
totalNumCandidates = (int) searchSize;
switch (mode) {
case Sequential:
for (int i = 0; i < totalNumCandidates; i++) {
order.add(i);
}
break;
case RandomOrder:
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
for (int i = 0; i < totalNumCandidates; i++) {
tempList.add(i);
}
Collections.shuffle(tempList, new RandomAdaptor(rng));
order.addAll(tempList);
break;
default:
throw new RuntimeException();
}
}
@Override
public boolean hasMoreCandidates() {
return !order.isEmpty();
}
@Override
public Candidate getCandidate() {
int next = order.remove();
//Next: max integer (candidate number) to values
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
Object value = null;
Exception e = null;
try {
value = parameterSpace.getValue(values);
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
//How? first map to index of num possible values. Then: to double values in range 0 to 1
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
//Based on: Nd4j Shape.ind2sub
int countNon1 = 0;
for( int i : numValuesPerParam)
if(i > 1)
countNon1++;
int denom = product;
int num = candidateIdx;
int[] index = new int[numValuesPerParam.length];
for (int i = index.length - 1; i >= 0; i--) {
denom /= numValuesPerParam[i];
index[i] = num / denom;
num %= denom;
}
//Now: convert indexes to values in range [0,1]
//min value -> 0
//max value -> 1
double[] out = new double[countNon1];
int outIdx = 0;
for (int i = 0; i < numValuesPerParam.length; i++) {
if (numValuesPerParam[i] > 1){
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
}
}
return out;
}
@Override
public String toString() {
return "GridSearchCandidateGenerator(mode=" + mode + ")";
}
}

View File

@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Map;
/**
* RandomSearchGenerator: generates candidates at random.<br>
* Note: if a probability distribution is provided for continuous hyperparameters,
* this will be taken into account
* when generating candidates. This allows the search to be weighted more towards
* certain values according to a probability
* density. For example: generate samples for learning rate according to log uniform distribution
*
* @author Alex Black
*/
@Slf4j
@EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
public class RandomSearchGenerator extends BaseCandidateGenerator {
@JsonCreator
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
@JsonProperty("initDone") boolean initDone) {
super(parameterSpace, dataParameters, initDone);
initialize();
}
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
this(parameterSpace, dataParameters, false);
}
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
this(parameterSpace, null, false);
}
@Override
public boolean hasMoreCandidates() {
return true;
}
@Override
public Candidate getCandidate() {
double[] randomValues = new double[parameterSpace.numParameters()];
for (int i = 0; i < randomValues.length; i++)
randomValues[i] = rng.nextDouble();
Object value = null;
Exception e = null;
try {
value = parameterSpace.getValue(randomValues);
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
@Override
public String toString() {
return "RandomSearchGenerator";
}
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic;
import lombok.Data;
/**
* Candidates are stored as Chromosome in the population model
*
* @author Alexandre Boulanger
*/
@Data
public class Chromosome {
/**
* The fitness score of the genes.
*/
protected final double fitness;
/**
* The genes.
*/
protected final double[] genes;
public Chromosome(double[] genes, double fitness) {
this.genes = genes;
this.fitness = fitness;
}
}

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic;
/**
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
*
* @author Alexandre Boulanger
*/
public class ChromosomeFactory {
private int chromosomeLength;
/**
* Called by the GeneticSearchCandidateGenerator.
*/
public void initializeInstance(int chromosomeLength) {
this.chromosomeLength = chromosomeLength;
}
/**
* Create a new instance of a Chromosome
*
* @param genes The genes
* @param fitness The fitness score
* @return A new instance of Chromosome
*/
public Chromosome createChromosome(double[] genes, double fitness) {
return new Chromosome(genes, fitness);
}
/**
* @return The number of genes in a chromosome
*/
public int getChromosomeLength() {
return chromosomeLength;
}
}

View File

@ -0,0 +1,120 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* A crossover operator that linearly combines the genes of two parents. <br>
* When a crossover is generated (with a of probability <i>crossover rate</i>), each genes is a linear combination of the corresponding genes of the parents.
* <p>
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
*
* @author Alexandre Boulanger
*/
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private final double crossoverRate;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public ArithmeticCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new ArithmeticCrossover(this);
}
}
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
double[] offspringValues = new double[parents[0].length];
if (rng.nextDouble() < crossoverRate) {
for (int i = 0; i < offspringValues.length; ++i) {
double t = rng.nextDouble();
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
}
return new CrossoverResult(true, offspringValues);
}
return new CrossoverResult(false, parents[0]);
}
}

View File

@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* Abstract class for all crossover operators
*
* @author Alexandre Boulanger
*/
public abstract class CrossoverOperator {
protected PopulationModel populationModel;
/**
* Will be called by the selection operator once the population model is instantiated.
*/
public void initializeInstance(PopulationModel populationModel) {
this.populationModel = populationModel;
}
/**
* Performs the crossover
*
* @return The crossover result. See {@link CrossoverResult}.
*/
public abstract CrossoverResult crossover();
}

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import lombok.Data;
/**
* Returned by a crossover operator
*
* @author Alexandre Boulanger
*/
@Data
public class CrossoverResult {
/**
* If false, there was no crossover and the operator simply returned the genes of a random parent.
* If true, the genes are the result of a crossover.
*/
private final boolean isModified;
/**
* The genes returned by the operator.
*/
private final double[] genes;
public CrossoverResult(boolean isModified, double[] genes) {
this.isModified = isModified;
this.genes = genes;
}
}

View File

@ -0,0 +1,178 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.nd4j.common.base.Preconditions;
import java.util.Deque;
/**
* The K-Point crossover will select at random multiple crossover points.<br>
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
*/
public class KPointCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private static final int DEFAULT_MIN_CROSSOVER = 1;
private static final int DEFAULT_MAX_CROSSOVER = 4;
private final double crossoverRate;
private final int minCrossovers;
private final int maxCrossovers;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* The number of crossovers points (default is min 1, max 4)
*
* @param min The minimum number
* @param max The maximum number
*/
public Builder numCrossovers(int min, int max) {
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
this.minCrossovers = min;
this.maxCrossovers = max;
return this;
}
/**
* Use a fixed number of crossover points
*
* @param num The number of crossovers
*/
public Builder numCrossovers(int num) {
Preconditions.checkState(num >= 0, "Num must be positive");
this.minCrossovers = num;
this.maxCrossovers = num;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public KPointCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new KPointCrossover(this);
}
}
private CrossoverPointsGenerator crossoverPointsGenerator;
private KPointCrossover(KPointCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.maxCrossovers = builder.maxCrossovers;
this.minCrossovers = builder.minCrossovers;
this.rng = builder.rng;
}
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
if (crossoverPointsGenerator == null) {
crossoverPointsGenerator =
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
}
return crossoverPointsGenerator;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
boolean isModified = false;
double[] resultGenes = parents[0];
if (rng.nextDouble() < crossoverRate) {
// Select crossover points
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
// Crossover
resultGenes = new double[parents[0].length];
int currentParent = 0;
int nextCrossover = crossoverPoints.pop();
for (int i = 0; i < resultGenes.length; ++i) {
if (i == nextCrossover) {
currentParent = currentParent == 0 ? 1 : 0;
nextCrossover = crossoverPoints.pop();
}
resultGenes[i] = parents[currentParent][i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -0,0 +1,123 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* The single point crossover will select a random point where every genes before that point comes from one parent
* and after which every genes comes from the other parent.
*
* @author Alexandre Boulanger
*/
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private final RandomGenerator rng;
private final double crossoverRate;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public SinglePointCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new SinglePointCrossover(this);
}
}
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
boolean isModified = false;
double[] resultGenes = parents[0];
if (rng.nextDouble() < crossoverRate) {
int chromosomeLength = parents[0].length;
// Crossover
resultGenes = new double[chromosomeLength];
int crossoverPoint = rng.nextInt(chromosomeLength);
for (int i = 0; i < resultGenes.length; ++i) {
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* Abstract class for all crossover operators that applies to two parents.
*
* @author Alexandre Boulanger
*/
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
protected final TwoParentSelection parentSelection;
/**
* @param parentSelection A parent selection that selects two parents.
*/
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
}
/**
* Will be called by the selection operator once the population model is instantiated.
*/
@Override
public void initializeInstance(PopulationModel populationModel) {
super.initializeInstance(populationModel);
parentSelection.initializeInstance(populationModel.getPopulation());
}
}

View File

@ -0,0 +1,136 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
*
* @author Alexandre Boulanger
*/
public class UniformCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
private final double crossoverRate;
private final double parentBiasFactor;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* A factor that will introduce a bias in the parent selection.<br>
*
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
*/
public Builder parentBiasFactor(double factor) {
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
factor);
this.parentBiasFactor = factor;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public UniformCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new UniformCrossover(this);
}
}
private UniformCrossover(UniformCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.parentBiasFactor = builder.parentBiasFactor;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
* One of the parent may be favored if the bias is different than 0.5
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
// select the parents
double[][] parents = parentSelection.selectParents();
double[] resultGenes = parents[0];
boolean isModified = false;
if (rng.nextDouble() < crossoverRate) {
// Crossover
resultGenes = new double[parents[0].length];
for (int i = 0; i < resultGenes.length; ++i) {
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* Abstract class for all parent selection behaviors
*
* @author Alexandre Boulanger
*/
public abstract class ParentSelection {
protected List<Chromosome> population;
/**
* Will be called by the crossover operator once the population model is instantiated.
*/
public void initializeInstance(List<Chromosome> population) {
this.population = population;
}
/**
* Performs the parent selection
*
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
*/
public abstract double[][] selectParents();
}

View File

@ -0,0 +1,65 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
/**
* A parent selection behavior that returns two random parents.
*
* @author Alexandre Boulanger
*/
public class RandomTwoParentSelection extends TwoParentSelection {
private final RandomGenerator rng;
public RandomTwoParentSelection() {
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public RandomTwoParentSelection(RandomGenerator rng) {
this.rng = rng;
}
/**
* Selects two random parents
*
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
*/
@Override
public double[][] selectParents() {
double[][] parents = new double[2][];
int parent1Idx = rng.nextInt(population.size());
int parent2Idx;
do {
parent2Idx = rng.nextInt(population.size());
} while (parent1Idx == parent2Idx);
parents[0] = population.get(parent1Idx).getGenes();
parents[1] = population.get(parent2Idx).getGenes();
return parents;
}
}

View File

@ -0,0 +1,25 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
/**
* Abstract class for all parent selection behaviors that selects two parents.
*
* @author Alexandre Boulanger
*/
public abstract class TwoParentSelection extends ParentSelection {
}

View File

@ -0,0 +1,68 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import java.util.*;
/**
* A helper class used by {@link KPointCrossover} to generate the crossover points
*
* @author Alexandre Boulanger
*/
public class CrossoverPointsGenerator {
private final int minCrossovers;
private final int maxCrossovers;
private final RandomGenerator rng;
private List<Integer> parameterIndexes;
/**
* Constructor
*
* @param chromosomeLength The number of genes
* @param minCrossovers The minimum number of crossover points to generate
* @param maxCrossovers The maximum number of crossover points to generate
* @param rng A RandomGenerator instance
*/
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
this.minCrossovers = minCrossovers;
this.maxCrossovers = maxCrossovers;
this.rng = rng;
parameterIndexes = new ArrayList<Integer>();
for (int i = 0; i < chromosomeLength; ++i) {
parameterIndexes.add(i);
}
}
/**
* Generate a list of crossover points.
*
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
*/
public Deque<Integer> getCrossoverPoints() {
Collections.shuffle(parameterIndexes);
List<Integer> crossoverPointLists =
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
Collections.sort(crossoverPointLists);
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
crossoverPoints.add(Integer.MAX_VALUE);
return crossoverPoints;
}
}

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* The cull operator will remove from the population the least desirables chromosomes.
*
* @author Alexandre Boulanger
*/
public interface CullOperator {
/**
* Will be called by the population model once created.
*/
void initializeInstance(PopulationModel populationModel);
/**
* Cull the population to the culled size.
*/
void cullPopulation();
/**
* @return The target population size after culling.
*/
int getCulledSize();
}

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
/**
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
*
* @author Alexandre Boulanger
*/
public class LeastFitCullOperator extends RatioCullOperator {
/**
* The default cull ratio is 1/3.
*/
public LeastFitCullOperator() {
super();
}
/**
* @param cullRatio The ratio of the maximum population size to be culled.<br>
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
*/
public LeastFitCullOperator(double cullRatio) {
super(cullRatio);
}
/**
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
*/
@Override
public void cullPopulation() {
while (population.size() > culledSize) {
population.remove(population.size() - 1);
}
}
}

View File

@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.nd4j.common.base.Preconditions;
import java.util.List;
/**
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
*
* @author Alexandre Boulanger
*/
public abstract class RatioCullOperator implements CullOperator {
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
protected int culledSize;
protected List<Chromosome> population;
protected final double cullRatio;
/**
* @param cullRatio The ratio of the maximum population size to be culled.<br>
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
*/
public RatioCullOperator(double cullRatio) {
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
cullRatio);
this.cullRatio = cullRatio;
}
/**
* The default cull ratio is 1/3
*/
public RatioCullOperator() {
this(DEFAULT_CULL_RATIO);
}
/**
* Will be called by the population model once created.
*/
public void initializeInstance(PopulationModel populationModel) {
this.population = populationModel.getPopulation();
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
}
/**
* @return The target population size after culling.
*/
@Override
public int getCulledSize() {
return culledSize;
}
}

View File

@ -0,0 +1,23 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
public class GeneticGenerationException extends RuntimeException {
public GeneticGenerationException(String message) {
super(message);
}
}

View File

@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
/**
* The mutation operator will apply a mutation to the given genes.
*
* @author Alexandre Boulanger
*/
public interface MutationOperator {
/**
* Performs a mutation.
*
* @param genes The genes to be mutated
* @return True if the genes were mutated, otherwise false.
*/
boolean mutate(double[] genes);
}

View File

@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.nd4j.common.base.Preconditions;
/**
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
*
* @author Alexandre Boulanger
*/
public class RandomMutationOperator implements MutationOperator {
private static final double DEFAULT_MUTATION_RATE = 0.005;
private final double mutationRate;
private final RandomGenerator rng;
public static class Builder {
private double mutationRate = DEFAULT_MUTATION_RATE;
private RandomGenerator rng;
/**
* Each gene will have this probability of being mutated.
*
* @param rate The mutation rate. (default 0.005)
*/
public Builder mutationRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.mutationRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
public RandomMutationOperator build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
return new RandomMutationOperator(this);
}
}
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
this.mutationRate = builder.mutationRate;
this.rng = builder.rng;
}
/**
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
*
* @param genes The genes to be mutated
* @return True if the genes were mutated, otherwise false.
*/
@Override
public boolean mutate(double[] genes) {
boolean hasMutation = false;
for (int i = 0; i < genes.length; ++i) {
if (rng.nextDouble() < mutationRate) {
genes[i] = rng.nextDouble();
hasMutation = true;
}
}
return hasMutation;
}
}

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.ArrayList;
import java.util.List;
/**
* A population initializer that build an empty population.
*
* @author Alexandre Boulanger
*/
public class EmptyPopulationInitializer implements PopulationInitializer {
/**
* Initialize an empty population
*
* @param size The maximum size of the population.
* @return The initialized population.
*/
@Override
public List<Chromosome> getInitializedPopulation(int size) {
return new ArrayList<>(size);
}
}

View File

@ -0,0 +1,36 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* An initializer that construct the population used by the population model.
*
* @author Alexandre Boulanger
*/
public interface PopulationInitializer {
/**
* Called by the population model to construct the population
*
* @param size The maximum size of the population
* @return An initialized population
*/
List<Chromosome> getInitializedPopulation(int size);
}

View File

@ -0,0 +1,35 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* A listener that is called when the population changes.
*
* @author Alexandre Boulanger
*/
public interface PopulationListener {
/**
* Called after the population has changed.
*
* @param population The population after it has changed.
*/
void onChanged(List<Chromosome> population);
}

View File

@ -0,0 +1,182 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* The population model handles all aspects of the population (initialization, additions and culling)
*
* @author Alexandre Boulanger
*/
public class PopulationModel {
private static final int DEFAULT_POPULATION_SIZE = 30;
private final CullOperator cullOperator;
private final List<PopulationListener> populationListeners = new ArrayList<>();
private Comparator<Chromosome> chromosomeComparator;
/**
* The maximum population size
*/
@Getter
private final int populationSize;
/**
* The population
*/
@Getter
public final List<Chromosome> population;
/**
* A comparator used when higher fitness value is better
*/
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
@Override
public int compare(Chromosome lhs, Chromosome rhs) {
return -Double.compare(lhs.getFitness(), rhs.getFitness());
}
}
/**
* A comparator used when lower fitness value is better
*/
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
@Override
public int compare(Chromosome lhs, Chromosome rhs) {
return Double.compare(lhs.getFitness(), rhs.getFitness());
}
}
public static class Builder {
private int populationSize = DEFAULT_POPULATION_SIZE;
private PopulationInitializer populationInitializer;
private CullOperator cullOperator;
/**
* Use an alternate population initialization behavior. Default is empty population.
*
* @param populationInitializer An instance of PopulationInitializer
*/
public Builder populationInitializer(PopulationInitializer populationInitializer) {
this.populationInitializer = populationInitializer;
return this;
}
/**
* The maximum population size. <br>
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
*
* @param size The maximum size of the population
*/
public Builder populationSize(int size) {
populationSize = size;
return this;
}
/**
* Use an alternate cull operator behavior. Default is least fit culling.
*
* @param cullOperator An instance of a CullOperator
*/
public Builder cullOperator(CullOperator cullOperator) {
this.cullOperator = cullOperator;
return this;
}
public PopulationModel build() {
if (cullOperator == null) {
cullOperator = new LeastFitCullOperator();
}
if (populationInitializer == null) {
populationInitializer = new EmptyPopulationInitializer();
}
return new PopulationModel(this);
}
}
public PopulationModel(PopulationModel.Builder builder) {
populationSize = builder.populationSize;
population = new ArrayList<>(builder.populationSize);
PopulationInitializer populationInitializer = builder.populationInitializer;
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
population.clear();
population.addAll(initializedPopulation);
cullOperator = builder.cullOperator;
cullOperator.initializeInstance(this);
}
/**
* Called by the GeneticSearchCandidateGenerator
*/
public void initializeInstance(boolean minimizeScore) {
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
}
/**
* Add a PopulationListener to the list of change listeners
* @param listener A PopulationListener instance
*/
public void addListener(PopulationListener listener) {
populationListeners.add(listener);
}
/**
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
*
* @param element The chromosome to be added
*/
public void add(Chromosome element) {
if (population.size() == populationSize) {
cullOperator.cullPopulation();
}
population.add(element);
Collections.sort(population, chromosomeComparator);
triggerPopulationChangedListeners(population);
}
/**
* @return Return false when the population is below the culled size, otherwise true. <br>
* Used by the selection operator to know if the population is still too small and should generate random genes.
*/
public boolean isReadyToBreed() {
return population.size() >= cullOperator.getCulledSize();
}
private void triggerPopulationChangedListeners(List<Chromosome> population) {
for (PopulationListener listener : populationListeners) {
listener.onChanged(population);
}
}
}

View File

@ -0,0 +1,197 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import java.util.Arrays;
/**
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
* will start to generate offsprings of parents selected in the population.
*
* @author Alexandre Boulanger
*/
public class GeneticSelectionOperator extends SelectionOperator {
private final static int PREVIOUS_GENES_TO_KEEP = 100;
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
private final CrossoverOperator crossoverOperator;
private final MutationOperator mutationOperator;
private final RandomGenerator rng;
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
private int previousGenesIdx = 0;
public static class Builder {
private ChromosomeFactory chromosomeFactory;
private PopulationModel populationModel;
private CrossoverOperator crossoverOperator;
private MutationOperator mutationOperator;
private RandomGenerator rng;
/**
* Use an alternate crossover behavior. Default is SinglePointCrossover.
*
* @param crossoverOperator An instance of CrossoverOperator
*/
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
this.crossoverOperator = crossoverOperator;
return this;
}
/**
* Use an alternate mutation behavior. Default is RandomMutationOperator.
*
* @param mutationOperator An instance of MutationOperator
*/
public Builder mutationOperator(MutationOperator mutationOperator) {
this.mutationOperator = mutationOperator;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
public GeneticSelectionOperator build() {
if (crossoverOperator == null) {
crossoverOperator = new SinglePointCrossover.Builder().build();
}
if (mutationOperator == null) {
mutationOperator = new RandomMutationOperator.Builder().build();
}
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
}
}
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
RandomGenerator rng) {
this.crossoverOperator = crossoverOperator;
this.mutationOperator = mutationOperator;
this.rng = rng;
}
/**
* Called by GeneticSearchCandidateGenerator
*/
@Override
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
super.initializeInstance(populationModel, chromosomeFactory);
crossoverOperator.initializeInstance(populationModel);
}
/**
* Build a new set of genes. Has two distinct modes of operation
* <ul>
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
* </ul>
* @return Returns the generated set of genes
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
* or if the crossover and the mutation operators can't generate a set,
* this exception is thrown.
*/
@Override
public double[] buildNextGenes() {
double[] result;
boolean hasAlreadyBeenTried;
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
do {
if (populationModel.isReadyToBreed()) {
result = buildOffspring();
} else {
result = buildRandomGenes();
}
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
}
} while (hasAlreadyBeenTried);
previousGenes[previousGenesIdx] = result;
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
return result;
}
private boolean hasAlreadyBeenTried(double[] genes) {
for (int i = 0; i < previousGenes.length; ++i) {
double[] current = previousGenes[i];
if (current != null && Arrays.equals(current, genes)) {
return true;
}
}
return false;
}
private double[] buildOffspring() {
double[] offspringValues;
boolean isModified;
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
do {
CrossoverResult crossoverResult = crossoverOperator.crossover();
offspringValues = crossoverResult.getGenes();
isModified = crossoverResult.isModified();
isModified |= mutationOperator.mutate(offspringValues);
if (!isModified && --attemptsRemaining == 0) {
throw new GeneticGenerationException(
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
MAX_NUM_GENERATION_ATTEMPTS));
}
} while (!isModified);
return offspringValues;
}
private double[] buildRandomGenes() {
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
for (int i = 0; i < randomValues.length; ++i) {
randomValues[i] = rng.nextDouble();
}
return randomValues;
}
}

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
*
* @author Alexandre Boulanger
*/
public abstract class SelectionOperator {
protected PopulationModel populationModel;
protected ChromosomeFactory chromosomeFactory;
/**
* Called by GeneticSearchCandidateGenerator
*/
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
this.populationModel = populationModel;
this.chromosomeFactory = chromosomeFactory;
}
/**
* Generate a new set of genes.
*/
public abstract double[] buildNextGenes();
}

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.generator.util;
import org.nd4j.common.function.Supplier;
import java.io.*;
public class SerializedSupplier<T> implements Serializable, Supplier<T> {
private byte[] asBytes;
public SerializedSupplier(T obj){
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(obj);
oos.flush();
oos.close();
asBytes = baos.toByteArray();
} catch (Exception e){
throw new RuntimeException("Error serializing object - must be serializable",e);
}
}
@Override
public T get() {
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))){
return (T)ois.readObject();
} catch (Exception e){
throw new RuntimeException("Error deserializing object",e);
}
}
}

Some files were not shown because too many files have changed in this diff Show More