Compare commits
No commits in common. "1c3496ad844392afb1411b78234e9720a023d972" and "1749b9e4afa92f11f03889e73f671237e4dc771b" have entirely different histories.
1c3496ad84
...
1749b9e4af
|
@ -1,15 +1,8 @@
|
||||||
FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
|
FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
||||||
#Build cmake version from source \
|
RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
|
||||||
#RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
|
tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
|
||||||
# tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
|
./bootstrap && make && make install
|
||||||
# ./bootstrap && make && make install
|
|
||||||
RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2-linux-x86_64.sh && \
|
|
||||||
mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \
|
|
||||||
rm cmake-3.24.2-linux-x86_64.sh
|
|
||||||
|
|
||||||
|
|
||||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
|
||||||
|
|
||||||
|
|
|
@ -36,8 +36,6 @@ pom.xml.versionsBackup
|
||||||
pom.xml.next
|
pom.xml.next
|
||||||
release.properties
|
release.properties
|
||||||
*dependency-reduced-pom.xml
|
*dependency-reduced-pom.xml
|
||||||
**/build/*
|
|
||||||
.gradle/*
|
|
||||||
|
|
||||||
# Specific for Nd4j
|
# Specific for Nd4j
|
||||||
*.md5
|
*.md5
|
||||||
|
@ -52,12 +50,12 @@ release.properties
|
||||||
*.dylib
|
*.dylib
|
||||||
.vs/
|
.vs/
|
||||||
.vscode/
|
.vscode/
|
||||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
|
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
|
||||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
|
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
|
||||||
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
|
nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
|
||||||
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
|
nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
|
||||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
|
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
|
||||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
|
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
|
||||||
|
|
||||||
doc_sources/
|
doc_sources/
|
||||||
doc_sources_*
|
doc_sources_*
|
||||||
|
@ -69,8 +67,8 @@ venv/
|
||||||
venv2/
|
venv2/
|
||||||
|
|
||||||
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
|
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
|
||||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
|
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
|
||||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
|
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
|
||||||
|
|
||||||
# Ignore meld temp files
|
# Ignore meld temp files
|
||||||
*.orig
|
*.orig
|
||||||
|
@ -84,15 +82,3 @@ bruai4j-native-common/cmake*
|
||||||
*.dll
|
*.dll
|
||||||
/bruai4j-native/bruai4j-native-common/blasbuild/
|
/bruai4j-native/bruai4j-native-common/blasbuild/
|
||||||
/bruai4j-native/bruai4j-native-common/build/
|
/bruai4j-native/bruai4j-native-common/build/
|
||||||
/cavis-native/cavis-native-lib/blasbuild/
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
|
|
||||||
/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
|
|
||||||
|
|
|
@ -1,82 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
label 'linux'
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage('prep-build-environment-linux-cpu') {
|
|
||||||
steps {
|
|
||||||
checkout scm
|
|
||||||
sh 'gcc --version'
|
|
||||||
sh 'cmake --version'
|
|
||||||
sh 'sh ./gradlew --version'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('build-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*stage('test-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
stage('publish-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
/* dockerfile {
|
|
||||||
filename 'Dockerfile'
|
|
||||||
dir '.docker'
|
|
||||||
label 'linux && cuda'
|
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
label 'linux && cuda'
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage('prep-build-environment-linux-cuda') {
|
|
||||||
steps {
|
|
||||||
checkout scm
|
|
||||||
//sh 'nvidia-smi'
|
|
||||||
sh 'nvcc --version'
|
|
||||||
sh 'gcc --version'
|
|
||||||
sh 'cmake --version'
|
|
||||||
sh 'sh ./gradlew --version'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('build-linux-cuda') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,66 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
dockerfile {
|
|
||||||
filename 'Dockerfile'
|
|
||||||
dir '.docker'
|
|
||||||
label 'linux && docker && cuda'
|
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage("Build all chip") {
|
|
||||||
parallel {
|
|
||||||
|
|
||||||
stage('prep-build-environment-linux-cuda') {
|
|
||||||
|
|
||||||
steps {
|
|
||||||
checkout scm
|
|
||||||
//sh 'nvidia-smi'
|
|
||||||
sh 'nvcc --version'
|
|
||||||
sh 'gcc --version'
|
|
||||||
sh 'cmake --version'
|
|
||||||
sh 'sh ./gradlew --version'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('build-linux-cuda') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,88 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
dockerfile {
|
|
||||||
filename 'Dockerfile'
|
|
||||||
dir '.docker'
|
|
||||||
label 'linux && docker'
|
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
|
||||||
//args '--gpus all'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage('prep-build-environment-linux-cpu') {
|
|
||||||
steps {
|
|
||||||
checkout scm
|
|
||||||
sh 'gcc --version'
|
|
||||||
sh 'cmake --version'
|
|
||||||
sh 'sh ./gradlew --version'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('build-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('test-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
//sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \
|
|
||||||
// -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
// -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('publish-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,49 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
dockerfile {
|
|
||||||
filename 'Dockerfile'
|
|
||||||
dir '.docker'
|
|
||||||
label 'linux && docker'
|
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
|
||||||
//args '--gpus all'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage('publish-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew publish -x test -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -20,13 +20,14 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
pipeline {
|
pipeline {
|
||||||
|
|
||||||
agent {
|
agent {
|
||||||
dockerfile {
|
dockerfile {
|
||||||
filename 'Dockerfile'
|
filename 'Dockerfile'
|
||||||
dir '.docker'
|
dir '.docker'
|
||||||
label 'linux && docker && cuda'
|
label 'linuxdocker'
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||||
//args '--gpus all' --needed for test only, you can build without GPU
|
args '--gpus all'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +35,7 @@ pipeline {
|
||||||
stage('prep-build-environment-linux-cuda') {
|
stage('prep-build-environment-linux-cuda') {
|
||||||
steps {
|
steps {
|
||||||
checkout scm
|
checkout scm
|
||||||
//sh 'nvidia-smi'
|
sh 'nvidia-smi'
|
||||||
sh 'nvcc --version'
|
sh 'nvcc --version'
|
||||||
sh 'gcc --version'
|
sh 'gcc --version'
|
||||||
sh 'cmake --version'
|
sh 'cmake --version'
|
||||||
|
@ -43,33 +44,19 @@ pipeline {
|
||||||
}
|
}
|
||||||
stage('build-linux-cuda') {
|
stage('build-linux-cuda') {
|
||||||
environment {
|
environment {
|
||||||
MAVEN = credentials('Internal_Archiva')
|
MAVEN = credentials('Internal Archiva')
|
||||||
OSSRH = credentials('OSSRH')
|
OSSRH = credentials('OSSRH')
|
||||||
}
|
}
|
||||||
|
|
||||||
steps {
|
steps {
|
||||||
withGradle {
|
withGradle {
|
||||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cuda \
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||||
}
|
}
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||||
}
|
}
|
||||||
}
|
|
||||||
stage('test-linux-cuda') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
pipeline {
|
|
||||||
agent {
|
|
||||||
dockerfile {
|
|
||||||
filename 'Dockerfile'
|
|
||||||
dir '.docker'
|
|
||||||
label 'WSL-docker'
|
|
||||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
|
||||||
//args '--gpus all'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stages {
|
|
||||||
stage('prep-build-environment-linux-cpu') {
|
|
||||||
steps {
|
|
||||||
checkout scm
|
|
||||||
sh 'gcc --version'
|
|
||||||
sh 'cmake --version'
|
|
||||||
sh 'sh ./gradlew --version'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage('build-linux-cpu') {
|
|
||||||
environment {
|
|
||||||
MAVEN = credentials('Internal_Archiva')
|
|
||||||
OSSRH = credentials('OSSRH')
|
|
||||||
}
|
|
||||||
|
|
||||||
steps {
|
|
||||||
withGradle {
|
|
||||||
sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cpu \
|
|
||||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
|
||||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
|
||||||
}
|
|
||||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -48,12 +48,12 @@ Deeplearning4J offers a very high level API for defining even complex neural net
|
||||||
you how LeNet, a convolutional neural network, is defined in DL4J.
|
you how LeNet, a convolutional neural network, is defined in DL4J.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.seed(seed)
|
.seed(seed)
|
||||||
.l2(0.0005)
|
.l2(0.0005)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.updater(new Adam(1e-3))
|
.updater(new Adam(1e-3))
|
||||||
|
.list()
|
||||||
.layer(new ConvolutionLayer.Builder(5, 5)
|
.layer(new ConvolutionLayer.Builder(5, 5)
|
||||||
.stride(1,1)
|
.stride(1,1)
|
||||||
.nOut(20)
|
.nOut(20)
|
||||||
|
@ -78,7 +78,7 @@ NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.nOut(outputNum)
|
.nOut(outputNum)
|
||||||
.activation(Activation.SOFTMAX)
|
.activation(Activation.SOFTMAX)
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutionalFlat(28,28,1))
|
.setInputType(InputType.convolutionalFlat(28,28,1))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
branches:
|
||||||
|
only:
|
||||||
|
- master
|
||||||
|
notifications:
|
||||||
|
email: false
|
||||||
|
dist: trusty
|
||||||
|
sudo: false
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- $HOME/.m2
|
||||||
|
language: java
|
||||||
|
jdk:
|
||||||
|
- openjdk8
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: linux
|
||||||
|
env: OS=linux-x86_64 SCALA=2.10
|
||||||
|
install: true
|
||||||
|
script: bash ./ci/build-linux-x86_64.sh
|
||||||
|
- os: linux
|
||||||
|
env: OS=linux-x86_64 SCALA=2.11
|
||||||
|
install: true
|
||||||
|
script: bash ./ci/build-linux-x86_64.sh
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Arbiter
|
||||||
|
|
||||||
|
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
|
||||||
|
|
||||||
|
|
||||||
|
## Modules
|
||||||
|
Arbiter contains the following modules:
|
||||||
|
|
||||||
|
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
|
||||||
|
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
|
||||||
|
|
||||||
|
|
||||||
|
## Hyperparameter Optimization Functionality
|
||||||
|
|
||||||
|
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
|
||||||
|
|
||||||
|
- Grid search
|
||||||
|
- Random search
|
||||||
|
|
||||||
|
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes, which are not among the methods listed above, may perform better still.
|
||||||
|
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
|
||||||
|
|
||||||
|
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
|
||||||
|
|
||||||
|
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
|
||||||
|
|
||||||
|
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
|
||||||
|
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate model configurations of some type ```C```. The following implementations are defined in arbiter-core:
|
||||||
|
- ```RandomSearchCandidateGenerator```
|
||||||
|
- ```GridSearchCandidateGenerator```
|
||||||
|
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
|
||||||
|
- A key concept here is that the score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
|
||||||
|
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
|
||||||
|
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
|
||||||
|
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
|
||||||
|
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
|
||||||
|
- Note that the ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
|
||||||
|
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
|
||||||
|
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
|
||||||
|
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
|
||||||
|
|
||||||
|
|
||||||
|
### Optimization of DeepLearning4J Models
|
||||||
|
|
||||||
|
(This section: forthcoming)
|
|
@ -0,0 +1,97 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<parent>
|
||||||
|
<artifactId>arbiter</artifactId>
|
||||||
|
<groupId>net.brutex.ai</groupId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>arbiter-core</artifactId>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
|
<name>arbiter-core</name>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>net.brutex.ai</groupId>
|
||||||
|
<artifactId>nd4j-api</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>com.google.code.findbugs</groupId>
|
||||||
|
<artifactId>*</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.guava</groupId>
|
||||||
|
<artifactId>guava</artifactId>
|
||||||
|
<version>${guava.jre.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
<version>${commons.lang.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-math3</artifactId>
|
||||||
|
<version>${commons.math.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-api</artifactId>
|
||||||
|
<version>${slf4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>joda-time</groupId>
|
||||||
|
<artifactId>joda-time</artifactId>
|
||||||
|
<version>${jodatime.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-annotations</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>net.brutex.ai</groupId>
|
||||||
|
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||||
|
<artifactId>jackson-datatype-joda</artifactId>
|
||||||
|
<version>${jackson.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>net.brutex.ai</groupId>
|
||||||
|
<artifactId>nd4j-native</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
<classifier>windows-x86_64</classifier>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
</project>
|
|
@ -0,0 +1,91 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<assembly>
|
||||||
|
<id>bin</id>
|
||||||
|
<!-- START SNIPPET: formats -->
|
||||||
|
<formats>
|
||||||
|
<format>tar.gz</format>
|
||||||
|
<!--
|
||||||
|
<format>tar.bz2</format>
|
||||||
|
<format>zip</format>
|
||||||
|
-->
|
||||||
|
</formats>
|
||||||
|
<!-- END SNIPPET: formats -->
|
||||||
|
|
||||||
|
<dependencySets>
|
||||||
|
<dependencySet>
|
||||||
|
<outputDirectory>lib</outputDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*:jar:*</include>
|
||||||
|
</includes>
|
||||||
|
<excludes>
|
||||||
|
<exclude>*:sources</exclude>
|
||||||
|
</excludes>
|
||||||
|
</dependencySet>
|
||||||
|
</dependencySets>
|
||||||
|
|
||||||
|
<!-- START SNIPPET: fileSets -->
|
||||||
|
<fileSets>
|
||||||
|
<fileSet>
|
||||||
|
<includes>
|
||||||
|
<include>readme.txt</include>
|
||||||
|
</includes>
|
||||||
|
</fileSet>
|
||||||
|
|
||||||
|
<fileSet>
|
||||||
|
<directory>src/main/resources/bin/</directory>
|
||||||
|
<outputDirectory>bin</outputDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>arbiter</include>
|
||||||
|
</includes>
|
||||||
|
<lineEnding>unix</lineEnding>
|
||||||
|
<fileMode>0755</fileMode>
|
||||||
|
</fileSet>
|
||||||
|
|
||||||
|
<fileSet>
|
||||||
|
<directory>examples</directory>
|
||||||
|
<outputDirectory>examples</outputDirectory>
|
||||||
|
<!--
|
||||||
|
<lineEnding>unix</lineEnding>
|
||||||
|
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
||||||
|
-->
|
||||||
|
</fileSet>
|
||||||
|
|
||||||
|
|
||||||
|
<!--
|
||||||
|
<fileSet>
|
||||||
|
<directory>src/bin</directory>
|
||||||
|
<outputDirectory>bin</outputDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>hello</include>
|
||||||
|
</includes>
|
||||||
|
<lineEnding>unix</lineEnding>
|
||||||
|
<fileMode>0755</fileMode>
|
||||||
|
</fileSet>
|
||||||
|
-->
|
||||||
|
|
||||||
|
<fileSet>
|
||||||
|
<directory>target</directory>
|
||||||
|
<outputDirectory>./</outputDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*.jar</include>
|
||||||
|
</includes>
|
||||||
|
</fileSet>
|
||||||
|
|
||||||
|
</fileSets>
|
||||||
|
<!-- END SNIPPET: fileSets -->
|
||||||
|
</assembly>
|
|
@ -0,0 +1,74 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Created by Alex on 23/07/2017.
|
||||||
|
*/
|
||||||
|
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||||
|
Map<String, ParameterSpace> m = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
//Need to manually build and walk the class heirarchy...
|
||||||
|
Class<?> currClass = this.getClass();
|
||||||
|
List<Class<?>> classHeirarchy = new ArrayList<>();
|
||||||
|
while (currClass != Object.class) {
|
||||||
|
classHeirarchy.add(currClass);
|
||||||
|
currClass = currClass.getSuperclass();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = classHeirarchy.size() - 1; i >= 0; i--) {
|
||||||
|
//Use reflection here to avoid a mass of boilerplate code...
|
||||||
|
Field[] allFields = classHeirarchy.get(i).getDeclaredFields();
|
||||||
|
|
||||||
|
for (Field f : allFields) {
|
||||||
|
|
||||||
|
String name = f.getName();
|
||||||
|
Class<?> fieldClass = f.getType();
|
||||||
|
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
|
||||||
|
|
||||||
|
if (!isParamSpacefield) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
f.setAccessible(true);
|
||||||
|
|
||||||
|
ParameterSpace<?> p;
|
||||||
|
try {
|
||||||
|
p = (ParameterSpace<?>) f.get(this);
|
||||||
|
} catch (IllegalAccessException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p != null) {
|
||||||
|
m.put(name, p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
|
||||||
|
import org.nd4j.common.function.Supplier;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Candidate: a proposed hyperparameter configuration.
|
||||||
|
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class Candidate<C> implements Serializable {
|
||||||
|
|
||||||
|
private Supplier<C> supplier;
|
||||||
|
private int index;
|
||||||
|
private double[] flatParameters;
|
||||||
|
private Map<String, Object> dataParameters;
|
||||||
|
private Exception exception;
|
||||||
|
|
||||||
|
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
|
||||||
|
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Candidate(C value, int index, double[] flatParameters) {
|
||||||
|
this(new SerializedSupplier<C>(value), index, flatParameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
|
||||||
|
this(value, index, flatParameters, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public C getValue(){
|
||||||
|
return supplier.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
|
||||||
|
* This abstraction allows for different ways of generating the next configuration to test; for example,
|
||||||
|
* random search, grid search, Bayesian optimization methods, etc.
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public interface CandidateGenerator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Is this candidate generator able to generate more candidates? This will always return true in some
|
||||||
|
* cases, but some search strategies have a limit (grid search, for example)
|
||||||
|
*/
|
||||||
|
boolean hasMoreCandidates();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a candidate hyperparameter configuration
|
||||||
|
*/
|
||||||
|
Candidate getCandidate();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Report results for the candidate generator.
|
||||||
|
*
|
||||||
|
* @param result The results to report
|
||||||
|
*/
|
||||||
|
void reportResults(OptimizationResult result);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Get the parameter space for this candidate generator
|
||||||
|
*/
|
||||||
|
ParameterSpace<?> getParameterSpace();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param rngSeed Set the random number generator seed for the candidate generator
|
||||||
|
*/
|
||||||
|
void setRngSeed(long rngSeed);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The type (class) of the generated candidates
|
||||||
|
*/
|
||||||
|
Class<?> getCandidateType();
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An optimization result represents the results of an optimization run, including the canditate configuration, the
|
||||||
|
* trained model, the score for that model, and index of the model
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
@JsonIgnoreProperties({"resultReference"})
|
||||||
|
public class OptimizationResult implements Serializable {
|
||||||
|
@JsonProperty
|
||||||
|
private Candidate candidate;
|
||||||
|
@JsonProperty
|
||||||
|
private Double score;
|
||||||
|
@JsonProperty
|
||||||
|
private int index;
|
||||||
|
@JsonProperty
|
||||||
|
private Object modelSpecificResults;
|
||||||
|
@JsonProperty
|
||||||
|
private CandidateInfo candidateInfo;
|
||||||
|
private ResultReference resultReference;
|
||||||
|
|
||||||
|
|
||||||
|
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
|
||||||
|
CandidateInfo candidateInfo, ResultReference resultReference) {
|
||||||
|
this.candidate = candidate;
|
||||||
|
this.score = score;
|
||||||
|
this.index = index;
|
||||||
|
this.modelSpecificResults = modelSpecificResults;
|
||||||
|
this.candidateInfo = candidateInfo;
|
||||||
|
this.resultReference = resultReference;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
|
||||||
|
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
|
||||||
|
* multiple nested ParameterSpaces
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public interface ParameterSpace<P> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
|
||||||
|
* mapping function (such as the prior probability distribution)
|
||||||
|
*
|
||||||
|
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
|
||||||
|
*/
|
||||||
|
P getValue(double[] parameterValues);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
|
||||||
|
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
|
||||||
|
*
|
||||||
|
* @return Number of hyperparameters to be optimized
|
||||||
|
*/
|
||||||
|
int numParameters();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
|
||||||
|
* nested parameter spaces
|
||||||
|
*/
|
||||||
|
List<ParameterSpace> collectLeaves();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
|
||||||
|
* nested parameter spaces. The map should be empty for leaf parameter spaces
|
||||||
|
*
|
||||||
|
* @return A map of nested parameter spaces
|
||||||
|
*/
|
||||||
|
Map<String, ParameterSpace> getNestedSpaces();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
|
||||||
|
*/
|
||||||
|
@JsonIgnore
|
||||||
|
boolean isLeaf();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
|
||||||
|
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
|
||||||
|
*
|
||||||
|
* @param indices Indices to set. Length should equal {@link #numParameters()}
|
||||||
|
*/
|
||||||
|
void setIndices(int... indices);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Properties;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
|
||||||
|
* that can be executed as a Callable
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public interface TaskCreator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||||
|
*
|
||||||
|
* @param candidate Candidate (model) configuration to be trained
|
||||||
|
* @param dataProvider DataProvider, for the data
|
||||||
|
* @param scoreFunction Score function to be used to evaluate the model
|
||||||
|
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||||
|
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
|
||||||
|
List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||||
|
*
|
||||||
|
* @param candidate Candidate (model) configuration to be trained
|
||||||
|
* @param dataSource Data source
|
||||||
|
* @param dataSourceProperties Properties (may be null) for the data source
|
||||||
|
* @param scoreFunction Score function to be used to evaluate the model
|
||||||
|
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||||
|
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||||
|
*/
|
||||||
|
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
|
||||||
|
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class TaskCreatorProvider {
|
||||||
|
|
||||||
|
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
|
||||||
|
|
||||||
|
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
|
||||||
|
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
|
||||||
|
try {
|
||||||
|
if(c == null){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return c.newInstance();
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
|
||||||
|
Class<? extends TaskCreator> creatorClass){
|
||||||
|
map.put(spaceClass, creatorClass);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.adapter;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
|
||||||
|
*
|
||||||
|
* @param <F> Type to convert from
|
||||||
|
* @param <T> Type to convert to
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
||||||
|
|
||||||
|
|
||||||
|
protected abstract T convertValue(F from);
|
||||||
|
|
||||||
|
protected abstract ParameterSpace<F> underlying();
|
||||||
|
|
||||||
|
protected abstract String underlyingName();
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public T getValue(double[] parameterValues) {
|
||||||
|
return convertValue(underlying().getValue(parameterValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int numParameters() {
|
||||||
|
return underlying().numParameters();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ParameterSpace> collectLeaves() {
|
||||||
|
ParameterSpace p = underlying();
|
||||||
|
if(p.isLeaf()){
|
||||||
|
return Collections.singletonList(p);
|
||||||
|
}
|
||||||
|
return underlying().collectLeaves();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||||
|
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLeaf() {
|
||||||
|
return false; //Underlying may be a leaf, however
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setIndices(int... indices) {
|
||||||
|
underlying().setIndices(indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return underlying().toString();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DataProvider interface abstracts out the providing of data
|
||||||
|
* @deprecated Use {@link DataSource}
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
@Deprecated
|
||||||
|
public interface DataProvider extends Serializable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get training data given some parameters for the data.
|
||||||
|
* Data parameters map is used to specify things like batch
|
||||||
|
* size data preprocessing
|
||||||
|
*
|
||||||
|
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||||
|
* @return training data
|
||||||
|
*/
|
||||||
|
Object trainData(Map<String, Object> dataParameters);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
|
||||||
|
* size data preprocessing
|
||||||
|
*
|
||||||
|
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||||
|
* @return training data
|
||||||
|
*/
|
||||||
|
Object testData(Map<String, Object> dataParameters);
|
||||||
|
|
||||||
|
Class<?> getDataType();
|
||||||
|
}
|
|
@ -0,0 +1,89 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a {@link DataProvider} for
|
||||||
|
* an {@link DataSetIteratorFactory} which
|
||||||
|
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
|
||||||
|
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
|
||||||
|
* for use with arbiter.
|
||||||
|
*
|
||||||
|
* This {@link DataProvider} is mainly meant for use for command line driven
|
||||||
|
* applications.
|
||||||
|
*
|
||||||
|
* @author Adam Gibson
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class DataSetIteratorFactoryProvider implements DataProvider {
|
||||||
|
|
||||||
|
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get training data given some parameters for the data.
|
||||||
|
* Data parameters map is used to specify things like batch
|
||||||
|
* size data preprocessing
|
||||||
|
*
|
||||||
|
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||||
|
* @return training data
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
|
||||||
|
return create(dataParameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get training data given some parameters for the data. Data parameters map
|
||||||
|
* is used to specify things like batch
|
||||||
|
* size data preprocessing
|
||||||
|
*
|
||||||
|
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||||
|
* @return training data
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
|
||||||
|
return create(dataParameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Class<?> getDataType() {
|
||||||
|
return DataSetIteratorFactory.class;
|
||||||
|
}
|
||||||
|
|
||||||
|
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
|
||||||
|
if (dataParameters == null)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Data parameters is null. Please specify a class name to create a dataset iterator.");
|
||||||
|
if (!dataParameters.containsKey(FACTORY_KEY))
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"No data set iterator factory class found. Please specify a class name with key "
|
||||||
|
+ FACTORY_KEY);
|
||||||
|
String value = dataParameters.get(FACTORY_KEY).toString();
|
||||||
|
try {
|
||||||
|
Class<? extends DataSetIteratorFactory> clazz =
|
||||||
|
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
||||||
|
return clazz.newInstance();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DataSource: defines where the data should come from for training and testing.
|
||||||
|
* Note that implementations must have a no-argument contsructor
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public interface DataSource extends Serializable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the current data source with the specified properties
|
||||||
|
* Note: These properties are fixed for the training instance, and are optionally provided by the user
|
||||||
|
* at the configuration stage.
|
||||||
|
* The properties could be anything - and are usually specific to each DataSource implementation.
|
||||||
|
* For example, values such as batch size could be set using these properties
|
||||||
|
* @param properties Properties to apply to the data source instance
|
||||||
|
*/
|
||||||
|
void configure(Properties properties);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
||||||
|
*/
|
||||||
|
Object trainData();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
||||||
|
*/
|
||||||
|
Object testData();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The type of data returned by {@link #trainData()} and {@link #testData()}.
|
||||||
|
* Usually DataSetIterator or MultiDataSetIterator
|
||||||
|
* @return Class of the objects returned by trainData and testData
|
||||||
|
*/
|
||||||
|
Class<?> getDataType();
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.evaluation;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ModelEvaluator: Used to conduct additional evaluation.
|
||||||
|
* For example, this may be classification performance on a test set or similar
|
||||||
|
*/
|
||||||
|
public interface ModelEvaluator extends Serializable {
|
||||||
|
Object evaluateModel(Object model, DataProvider dataProvider);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The model types supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedModelTypes();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The datatypes supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedDataTypes();
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple class to store optimization results in-memory.
|
||||||
|
* Not recommended for large (or a large number of) models.
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class InMemoryResultSaver implements ResultSaver {
|
||||||
|
@Override
|
||||||
|
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
|
||||||
|
return new InMemoryResult(result, modelResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Class<?>> getSupportedCandidateTypes() {
|
||||||
|
return Collections.<Class<?>>singletonList(Object.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Class<?>> getSupportedModelTypes() {
|
||||||
|
return Collections.<Class<?>>singletonList(Object.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
private static class InMemoryResult implements ResultReference {
|
||||||
|
private OptimizationResult result;
|
||||||
|
private Object modelResult;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OptimizationResult getResult() throws IOException {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object getResultModel() throws IOException {
|
||||||
|
return modelResult;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
|
||||||
|
* parameters each)
|
||||||
|
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
|
||||||
|
* and we can easily load it back into memory (if/when required) using the getResult() method
|
||||||
|
*/
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public interface ResultReference {
|
||||||
|
|
||||||
|
OptimizationResult getResult() throws IOException;
|
||||||
|
|
||||||
|
Object getResultModel() throws IOException;
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
|
||||||
|
* regardless of where/how they are saved.
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public interface ResultSaver {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the model (including configuration and any additional evaluation/results)
|
||||||
|
*
|
||||||
|
* @param result Optimization result for the model to save
|
||||||
|
* @param modelResult Model result to save
|
||||||
|
* @return ResultReference, such that the result can be loaded back into memory
|
||||||
|
* @throws IOException If IO error occurs during model saving
|
||||||
|
*/
|
||||||
|
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The candidate types supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedCandidateTypes();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The model types supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedModelTypes();
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.score;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ScoreFunction defines the objective of hyperparameter optimization.
|
||||||
|
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
|
||||||
|
* in the configuration.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public interface ScoreFunction extends Serializable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate and return the score, for the given model and data provider
|
||||||
|
*
|
||||||
|
* @param model Model to score
|
||||||
|
* @param dataProvider Data provider - data to use
|
||||||
|
* @param dataParameters Parameters for data
|
||||||
|
* @return Calculated score
|
||||||
|
*/
|
||||||
|
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate and return the score, for the given model and data provider
|
||||||
|
*
|
||||||
|
* @param model Model to score
|
||||||
|
* @param dataSource Data source
|
||||||
|
* @param dataSourceProperties data source properties
|
||||||
|
* @return Calculated score
|
||||||
|
*/
|
||||||
|
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Should this score function be minimized or maximized?
|
||||||
|
*
|
||||||
|
* @return true if score should be minimized, false if score should be maximized
|
||||||
|
*/
|
||||||
|
boolean minimize();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The model types supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedModelTypes();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The data types supported by this class
|
||||||
|
*/
|
||||||
|
List<Class<?>> getSupportedDataTypes();
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
|
||||||
|
* Note that this is counted as number of completed candidates, plus number of failed candidates.
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
@Data
|
||||||
|
public class MaxCandidatesCondition implements TerminationCondition {
|
||||||
|
@JsonProperty
|
||||||
|
private int maxCandidates;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||||
|
//No op
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||||
|
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "MaxCandidatesCondition(" + maxCandidates + ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||||
|
import org.joda.time.format.DateTimeFormat;
|
||||||
|
import org.joda.time.format.DateTimeFormatter;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Terminate hyperparameter optimization after
|
||||||
|
* a fixed amount of time has passed
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor
|
||||||
|
@Data
|
||||||
|
public class MaxTimeCondition implements TerminationCondition {
|
||||||
|
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
|
||||||
|
|
||||||
|
private long duration;
|
||||||
|
private TimeUnit timeUnit;
|
||||||
|
private long startTime;
|
||||||
|
private long endTime;
|
||||||
|
|
||||||
|
|
||||||
|
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
|
||||||
|
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
|
||||||
|
this.duration = duration;
|
||||||
|
this.timeUnit = timeUnit;
|
||||||
|
this.startTime = startTime;
|
||||||
|
this.endTime = endTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param duration Duration of time
|
||||||
|
* @param timeUnit Unit that the duration is specified in
|
||||||
|
*/
|
||||||
|
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
|
||||||
|
this.duration = duration;
|
||||||
|
this.timeUnit = timeUnit;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||||
|
startTime = System.currentTimeMillis();
|
||||||
|
this.endTime = startTime + timeUnit.toMillis(duration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||||
|
return System.currentTimeMillis() >= endTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
if (startTime > 0) {
|
||||||
|
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
|
||||||
|
+ "\",end=\"" + formatter.print(endTime) + "\")";
|
||||||
|
} else {
|
||||||
|
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||||
|
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Global termination condition for conducting hyperparameter optimization.
|
||||||
|
* Termination conditions are used to determine if/when the optimization should stop.
|
||||||
|
*/
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
public interface TerminationCondition {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize the termination condition (such as starting timers, etc).
|
||||||
|
*/
|
||||||
|
void initialize(IOptimizationRunner optimizationRunner);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determine whether optimization should be terminated
|
||||||
|
*
|
||||||
|
* @param optimizationRunner Optimization runner
|
||||||
|
* @return true if learning should be terminated, false otherwise
|
||||||
|
*/
|
||||||
|
boolean terminate(IOptimizationRunner optimizationRunner);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,226 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.config;
|
||||||
|
|
||||||
|
import lombok.*;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.lang.reflect.Constructor;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OptimizationConfiguration ties together all of the various
|
||||||
|
* components (such as data, score functions, result saving etc)
|
||||||
|
* required to execute hyperparameter optimization.
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
public class OptimizationConfiguration {
|
||||||
|
@JsonSerialize
|
||||||
|
private DataProvider dataProvider;
|
||||||
|
@JsonSerialize
|
||||||
|
private Class<? extends DataSource> dataSource;
|
||||||
|
@JsonSerialize
|
||||||
|
private Properties dataSourceProperties;
|
||||||
|
@JsonSerialize
|
||||||
|
private CandidateGenerator candidateGenerator;
|
||||||
|
@JsonSerialize
|
||||||
|
private ResultSaver resultSaver;
|
||||||
|
@JsonSerialize
|
||||||
|
private ScoreFunction scoreFunction;
|
||||||
|
@JsonSerialize
|
||||||
|
private List<TerminationCondition> terminationConditions;
|
||||||
|
@JsonSerialize
|
||||||
|
private Long rngSeed;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
private long executionStartTime;
|
||||||
|
|
||||||
|
|
||||||
|
private OptimizationConfiguration(Builder builder) {
|
||||||
|
this.dataProvider = builder.dataProvider;
|
||||||
|
this.dataSource = builder.dataSource;
|
||||||
|
this.dataSourceProperties = builder.dataSourceProperties;
|
||||||
|
this.candidateGenerator = builder.candidateGenerator;
|
||||||
|
this.resultSaver = builder.resultSaver;
|
||||||
|
this.scoreFunction = builder.scoreFunction;
|
||||||
|
this.terminationConditions = builder.terminationConditions;
|
||||||
|
this.rngSeed = builder.rngSeed;
|
||||||
|
|
||||||
|
if (rngSeed != null)
|
||||||
|
candidateGenerator.setRngSeed(rngSeed);
|
||||||
|
|
||||||
|
//Validate the configuration: data types, score types, etc
|
||||||
|
//TODO
|
||||||
|
|
||||||
|
//Validate that the dataSource has a no-arg constructor
|
||||||
|
if (dataSource != null) {
|
||||||
|
try {
|
||||||
|
dataSource.getConstructor();
|
||||||
|
} catch (NoSuchMethodException e) {
|
||||||
|
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private DataProvider dataProvider;
|
||||||
|
private Class<? extends DataSource> dataSource;
|
||||||
|
private Properties dataSourceProperties;
|
||||||
|
private CandidateGenerator candidateGenerator;
|
||||||
|
private ResultSaver resultSaver;
|
||||||
|
private ScoreFunction scoreFunction;
|
||||||
|
private List<TerminationCondition> terminationConditions;
|
||||||
|
private Long rngSeed;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #dataSource(Class, Properties)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public Builder dataProvider(DataProvider dataProvider) {
|
||||||
|
this.dataProvider = dataProvider;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DataSource: defines where the data should come from for training and testing.
|
||||||
|
* Note that implementations must have a no-argument contsructor
|
||||||
|
*
|
||||||
|
* @param dataSource Class for the data source
|
||||||
|
* @param dataSourceProperties May be null. Properties for configuring the data source
|
||||||
|
*/
|
||||||
|
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
|
||||||
|
this.dataSource = dataSource;
|
||||||
|
this.dataSourceProperties = dataSourceProperties;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
|
||||||
|
this.candidateGenerator = candidateGenerator;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder modelSaver(ResultSaver resultSaver) {
|
||||||
|
this.resultSaver = resultSaver;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder scoreFunction(ScoreFunction scoreFunction) {
|
||||||
|
this.scoreFunction = scoreFunction;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Termination conditions to use
|
||||||
|
*
|
||||||
|
* @param conditions
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Builder terminationConditions(TerminationCondition... conditions) {
|
||||||
|
terminationConditions = Arrays.asList(conditions);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
|
||||||
|
this.terminationConditions = terminationConditions;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder rngSeed(long rngSeed) {
|
||||||
|
this.rngSeed = rngSeed;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OptimizationConfiguration build() {
|
||||||
|
return new OptimizationConfiguration(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an optimization configuration from the json
|
||||||
|
*
|
||||||
|
* @param json the json to create the config from
|
||||||
|
* For type definitions
|
||||||
|
* @see OptimizationConfiguration
|
||||||
|
*/
|
||||||
|
public static OptimizationConfiguration fromYaml(String json) {
|
||||||
|
try {
|
||||||
|
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an optimization configuration from the json
|
||||||
|
*
|
||||||
|
* @param json the json to create the config from
|
||||||
|
* @see OptimizationConfiguration
|
||||||
|
*/
|
||||||
|
public static OptimizationConfiguration fromJson(String json) {
|
||||||
|
try {
|
||||||
|
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a json configuration of this optimization configuration
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public String toJson() {
|
||||||
|
try {
|
||||||
|
return JsonMapper.getMapper().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a yaml configuration of this optimization configuration
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public String toYaml() {
|
||||||
|
try {
|
||||||
|
return JsonMapper.getYamlMapper().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||||
|
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||||
|
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
|
||||||
|
*/
|
||||||
|
public class DegenerateIntegerDistribution implements IntegerDistribution {
|
||||||
|
private int value;
|
||||||
|
|
||||||
|
public DegenerateIntegerDistribution(int value) {
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double probability(int x) {
|
||||||
|
return (x == value ? 1.0 : 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double cumulativeProbability(int x) {
|
||||||
|
return (x >= value ? 1.0 : 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
|
||||||
|
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getNumericalMean() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getNumericalVariance() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getSupportLowerBound() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getSupportUpperBound() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSupportConnected() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reseedRandomGenerator(long seed) {
|
||||||
|
//no op
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int sample() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int[] sample(int sampleSize) {
|
||||||
|
int[] out = new int[sampleSize];
|
||||||
|
for (int i = 0; i < out.length; i++)
|
||||||
|
out[i] = value;
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
|
||||||
|
* don't implement serializable etc.
|
||||||
|
* Which makes unit testing etc quite difficult.
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class DistributionUtils {
|
||||||
|
|
||||||
|
private DistributionUtils() {}
|
||||||
|
|
||||||
|
|
||||||
|
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
|
||||||
|
if (a.getClass() != b.getClass())
|
||||||
|
return false;
|
||||||
|
Class<?> c = a.getClass();
|
||||||
|
if (c == BetaDistribution.class) {
|
||||||
|
BetaDistribution ba = (BetaDistribution) a;
|
||||||
|
BetaDistribution bb = (BetaDistribution) b;
|
||||||
|
|
||||||
|
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
|
||||||
|
} else if (c == CauchyDistribution.class) {
|
||||||
|
CauchyDistribution ca = (CauchyDistribution) a;
|
||||||
|
CauchyDistribution cb = (CauchyDistribution) b;
|
||||||
|
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
|
||||||
|
} else if (c == ChiSquaredDistribution.class) {
|
||||||
|
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
|
||||||
|
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
|
||||||
|
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
|
||||||
|
} else if (c == ExponentialDistribution.class) {
|
||||||
|
ExponentialDistribution ea = (ExponentialDistribution) a;
|
||||||
|
ExponentialDistribution eb = (ExponentialDistribution) b;
|
||||||
|
return ea.getMean() == eb.getMean();
|
||||||
|
} else if (c == FDistribution.class) {
|
||||||
|
FDistribution fa = (FDistribution) a;
|
||||||
|
FDistribution fb = (FDistribution) b;
|
||||||
|
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
|
||||||
|
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
|
||||||
|
} else if (c == GammaDistribution.class) {
|
||||||
|
GammaDistribution ga = (GammaDistribution) a;
|
||||||
|
GammaDistribution gb = (GammaDistribution) b;
|
||||||
|
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
|
||||||
|
} else if (c == LevyDistribution.class) {
|
||||||
|
LevyDistribution la = (LevyDistribution) a;
|
||||||
|
LevyDistribution lb = (LevyDistribution) b;
|
||||||
|
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
|
||||||
|
} else if (c == LogNormalDistribution.class) {
|
||||||
|
LogNormalDistribution la = (LogNormalDistribution) a;
|
||||||
|
LogNormalDistribution lb = (LogNormalDistribution) b;
|
||||||
|
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
|
||||||
|
} else if (c == NormalDistribution.class) {
|
||||||
|
NormalDistribution na = (NormalDistribution) a;
|
||||||
|
NormalDistribution nb = (NormalDistribution) b;
|
||||||
|
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
|
||||||
|
} else if (c == ParetoDistribution.class) {
|
||||||
|
ParetoDistribution pa = (ParetoDistribution) a;
|
||||||
|
ParetoDistribution pb = (ParetoDistribution) b;
|
||||||
|
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
|
||||||
|
} else if (c == TDistribution.class) {
|
||||||
|
TDistribution ta = (TDistribution) a;
|
||||||
|
TDistribution tb = (TDistribution) b;
|
||||||
|
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
|
||||||
|
} else if (c == TriangularDistribution.class) {
|
||||||
|
TriangularDistribution ta = (TriangularDistribution) a;
|
||||||
|
TriangularDistribution tb = (TriangularDistribution) b;
|
||||||
|
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
|
||||||
|
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
|
||||||
|
} else if (c == UniformRealDistribution.class) {
|
||||||
|
UniformRealDistribution ua = (UniformRealDistribution) a;
|
||||||
|
UniformRealDistribution ub = (UniformRealDistribution) b;
|
||||||
|
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
|
||||||
|
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||||
|
} else if (c == WeibullDistribution.class) {
|
||||||
|
WeibullDistribution wa = (WeibullDistribution) a;
|
||||||
|
WeibullDistribution wb = (WeibullDistribution) b;
|
||||||
|
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
|
||||||
|
} else if (c == LogUniformDistribution.class ){
|
||||||
|
LogUniformDistribution lu_a = (LogUniformDistribution)a;
|
||||||
|
LogUniformDistribution lu_b = (LogUniformDistribution)b;
|
||||||
|
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
|
||||||
|
} else {
|
||||||
|
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
|
||||||
|
if (a.getClass() != b.getClass())
|
||||||
|
return false;
|
||||||
|
Class<?> c = a.getClass();
|
||||||
|
|
||||||
|
if (c == BinomialDistribution.class) {
|
||||||
|
BinomialDistribution ba = (BinomialDistribution) a;
|
||||||
|
BinomialDistribution bb = (BinomialDistribution) b;
|
||||||
|
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
|
||||||
|
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
|
||||||
|
} else if (c == GeometricDistribution.class) {
|
||||||
|
GeometricDistribution ga = (GeometricDistribution) a;
|
||||||
|
GeometricDistribution gb = (GeometricDistribution) b;
|
||||||
|
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
|
||||||
|
} else if (c == HypergeometricDistribution.class) {
|
||||||
|
HypergeometricDistribution ha = (HypergeometricDistribution) a;
|
||||||
|
HypergeometricDistribution hb = (HypergeometricDistribution) b;
|
||||||
|
return ha.getPopulationSize() == hb.getPopulationSize()
|
||||||
|
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
|
||||||
|
&& ha.getSampleSize() == hb.getSampleSize();
|
||||||
|
} else if (c == PascalDistribution.class) {
|
||||||
|
PascalDistribution pa = (PascalDistribution) a;
|
||||||
|
PascalDistribution pb = (PascalDistribution) b;
|
||||||
|
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
|
||||||
|
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
|
||||||
|
} else if (c == PoissonDistribution.class) {
|
||||||
|
PoissonDistribution pa = (PoissonDistribution) a;
|
||||||
|
PoissonDistribution pb = (PoissonDistribution) b;
|
||||||
|
return pa.getMean() == pb.getMean();
|
||||||
|
} else if (c == UniformIntegerDistribution.class) {
|
||||||
|
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
|
||||||
|
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
|
||||||
|
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
|
||||||
|
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||||
|
} else if (c == ZipfDistribution.class) {
|
||||||
|
ZipfDistribution za = (ZipfDistribution) a;
|
||||||
|
ZipfDistribution zb = (ZipfDistribution) b;
|
||||||
|
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
|
||||||
|
} else {
|
||||||
|
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||||
|
|
||||||
|
import com.google.common.base.Preconditions;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.apache.commons.math3.distribution.RealDistribution;
|
||||||
|
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||||
|
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Log uniform distribution, with support in range [min, max] for min > 0
|
||||||
|
*
|
||||||
|
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class LogUniformDistribution implements RealDistribution {
|
||||||
|
|
||||||
|
@Getter private final double min;
|
||||||
|
@Getter private final double max;
|
||||||
|
|
||||||
|
private final double logMin;
|
||||||
|
private final double logMax;
|
||||||
|
|
||||||
|
private transient Random rng = new Random();
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param min Minimum value
|
||||||
|
* @param max Maximum value
|
||||||
|
*/
|
||||||
|
public LogUniformDistribution(double min, double max) {
|
||||||
|
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
|
||||||
|
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
|
||||||
|
+ min + "," + max + ")");
|
||||||
|
this.min = min;
|
||||||
|
this.max = max;
|
||||||
|
|
||||||
|
this.logMin = Math.log(min);
|
||||||
|
this.logMax = Math.log(max);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double probability(double x) {
|
||||||
|
if(x < min || x > max){
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1.0 / (x * (logMax - logMin));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double density(double x) {
|
||||||
|
return probability(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double cumulativeProbability(double x) {
|
||||||
|
if(x <= min){
|
||||||
|
return 0.0;
|
||||||
|
} else if(x >= max){
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (Math.log(x)-logMin)/(logMax-logMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
|
||||||
|
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||||
|
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
|
||||||
|
return Math.exp(p * (logMax-logMin) + logMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getNumericalMean() {
|
||||||
|
return (max-min)/(logMax-logMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getNumericalVariance() {
|
||||||
|
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
|
||||||
|
return d1 / (2*Math.pow(logMax-logMin, 2.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getSupportLowerBound() {
|
||||||
|
return min;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getSupportUpperBound() {
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSupportLowerBoundInclusive() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSupportUpperBoundInclusive() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isSupportConnected() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reseedRandomGenerator(long seed) {
|
||||||
|
rng.setSeed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double sample() {
|
||||||
|
return inverseCumulativeProbability(rng.nextDouble());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double[] sample(int sampleSize) {
|
||||||
|
double[] d = new double[sampleSize];
|
||||||
|
for( int i=0; i<sampleSize; i++ ){
|
||||||
|
d[i] = sample();
|
||||||
|
}
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
|
||||||
|
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
|
||||||
|
* are built.
|
||||||
|
*
|
||||||
|
* @param <T> Type of candidates to generate
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
|
||||||
|
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
|
||||||
|
protected ParameterSpace<T> parameterSpace;
|
||||||
|
protected AtomicInteger candidateCounter = new AtomicInteger(0);
|
||||||
|
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
protected Map<String, Object> dataParameters;
|
||||||
|
protected boolean initDone = false;
|
||||||
|
|
||||||
|
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
|
||||||
|
boolean initDone) {
|
||||||
|
this.parameterSpace = parameterSpace;
|
||||||
|
this.dataParameters = dataParameters;
|
||||||
|
this.initDone = initDone;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void initialize() {
|
||||||
|
if(!initDone) {
|
||||||
|
//First: collect leaf parameter spaces objects and remove duplicates
|
||||||
|
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||||
|
|
||||||
|
//Second: assign each a number
|
||||||
|
int i = 0;
|
||||||
|
for (ParameterSpace ps : noDuplicatesList) {
|
||||||
|
int np = ps.numParameters();
|
||||||
|
if (np == 1) {
|
||||||
|
ps.setIndices(i++);
|
||||||
|
} else {
|
||||||
|
int[] values = new int[np];
|
||||||
|
for (int j = 0; j < np; j++)
|
||||||
|
values[j] = i++;
|
||||||
|
ps.setIndices(values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initDone = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParameterSpace<T> getParameterSpace() {
|
||||||
|
return parameterSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reportResults(OptimizationResult result) {
|
||||||
|
//No op
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setRngSeed(long rngSeed) {
|
||||||
|
rng.setSeed(rngSeed);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,187 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Uses a genetic algorithm to generate candidates.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
protected final PopulationModel populationModel;
|
||||||
|
|
||||||
|
protected final ChromosomeFactory chromosomeFactory;
|
||||||
|
protected final SelectionOperator selectionOperator;
|
||||||
|
|
||||||
|
protected boolean hasMoreCandidates = true;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
protected final ParameterSpace<?> parameterSpace;
|
||||||
|
|
||||||
|
protected Map<String, Object> dataParameters;
|
||||||
|
protected boolean initDone;
|
||||||
|
protected boolean minimizeScore;
|
||||||
|
protected PopulationModel populationModel;
|
||||||
|
protected ChromosomeFactory chromosomeFactory;
|
||||||
|
protected SelectionOperator selectionOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||||
|
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
|
||||||
|
*/
|
||||||
|
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
|
||||||
|
this.parameterSpace = parameterSpace;
|
||||||
|
this.minimizeScore = scoreFunction.minimize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param populationModel The PopulationModel instance to use.
|
||||||
|
*/
|
||||||
|
public Builder populationModel(PopulationModel populationModel) {
|
||||||
|
this.populationModel = populationModel;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
|
||||||
|
*/
|
||||||
|
public Builder selectionOperator(SelectionOperator selectionOperator) {
|
||||||
|
this.selectionOperator = selectionOperator;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder dataParameters(Map<String, Object> dataParameters) {
|
||||||
|
|
||||||
|
this.dataParameters = dataParameters;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
|
||||||
|
this.initDone = initDone;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param chromosomeFactory The ChromosomeFactory to use
|
||||||
|
*/
|
||||||
|
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
|
||||||
|
this.chromosomeFactory = chromosomeFactory;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GeneticSearchCandidateGenerator build() {
|
||||||
|
if (populationModel == null) {
|
||||||
|
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
|
||||||
|
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chromosomeFactory == null) {
|
||||||
|
chromosomeFactory = new ChromosomeFactory();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selectionOperator == null) {
|
||||||
|
selectionOperator = new GeneticSelectionOperator.Builder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new GeneticSearchCandidateGenerator(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private GeneticSearchCandidateGenerator(Builder builder) {
|
||||||
|
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
|
||||||
|
|
||||||
|
initialize();
|
||||||
|
|
||||||
|
chromosomeFactory = builder.chromosomeFactory;
|
||||||
|
populationModel = builder.populationModel;
|
||||||
|
selectionOperator = builder.selectionOperator;
|
||||||
|
|
||||||
|
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
|
||||||
|
populationModel.initializeInstance(builder.minimizeScore);
|
||||||
|
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasMoreCandidates() {
|
||||||
|
return hasMoreCandidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Candidate getCandidate() {
|
||||||
|
|
||||||
|
double[] values = null;
|
||||||
|
Object value = null;
|
||||||
|
Exception e = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
values = selectionOperator.buildNextGenes();
|
||||||
|
value = parameterSpace.getValue(values);
|
||||||
|
} catch (GeneticGenerationException e2) {
|
||||||
|
log.warn("Error generating candidate", e2);
|
||||||
|
e = e2;
|
||||||
|
hasMoreCandidates = false;
|
||||||
|
} catch (Exception e2) {
|
||||||
|
log.warn("Error getting configuration for candidate", e2);
|
||||||
|
e = e2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Class<?> getCandidateType() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "GeneticSearchCandidateGenerator";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reportResults(OptimizationResult result) {
|
||||||
|
if (result.getScore() == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
|
||||||
|
result.getScore());
|
||||||
|
populationModel.add(newChromosome);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,232 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.math3.random.RandomAdaptor;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||||
|
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
|
||||||
|
* Note that:<br>
|
||||||
|
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
|
||||||
|
* that hyperparameter<br>
|
||||||
|
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
|
||||||
|
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
|
||||||
|
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
|
||||||
|
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
|
||||||
|
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
|
||||||
|
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
|
||||||
|
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||||
|
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In what order should candidates be generated?<br>
|
||||||
|
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
|
||||||
|
* will be changed least rapidly.<br>
|
||||||
|
* <b>RandomOrder</b>: generate candidates in a random order<br>
|
||||||
|
* In both cases, the same candidates will be generated; only the order of generation is different
|
||||||
|
*/
|
||||||
|
public enum Mode {
|
||||||
|
Sequential, RandomOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
private final int discretizationCount;
|
||||||
|
private final Mode mode;
|
||||||
|
|
||||||
|
private int[] numValuesPerParam;
|
||||||
|
@Getter
|
||||||
|
private int totalNumCandidates;
|
||||||
|
private Queue<Integer> order;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||||
|
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||||
|
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||||
|
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||||
|
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||||
|
* in which candidates should be generated.
|
||||||
|
*/
|
||||||
|
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||||
|
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
|
||||||
|
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||||
|
@JsonProperty("initDone") boolean initDone) {
|
||||||
|
super(parameterSpace, dataParameters, initDone);
|
||||||
|
this.discretizationCount = discretizationCount;
|
||||||
|
this.mode = mode;
|
||||||
|
initialize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||||
|
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||||
|
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||||
|
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||||
|
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||||
|
* in which candidates should be generated.
|
||||||
|
*/
|
||||||
|
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
|
||||||
|
Map<String, Object> dataParameters){
|
||||||
|
this(parameterSpace, discretizationCount, mode, dataParameters, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void initialize() {
|
||||||
|
super.initialize();
|
||||||
|
|
||||||
|
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||||
|
int nParams = leaves.size();
|
||||||
|
|
||||||
|
//Work out for each parameter: is it continuous or discrete?
|
||||||
|
// for grid search: discrete values are grid-searchable as-is
|
||||||
|
// continuous values: discretize using 'discretizationCount' bins
|
||||||
|
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
|
||||||
|
numValuesPerParam = new int[nParams];
|
||||||
|
long searchSize = 1;
|
||||||
|
for (int i = 0; i < nParams; i++) {
|
||||||
|
ParameterSpace ps = leaves.get(i);
|
||||||
|
if (ps instanceof DiscreteParameterSpace) {
|
||||||
|
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
|
||||||
|
numValuesPerParam[i] = dps.numValues();
|
||||||
|
} else if (ps instanceof IntegerParameterSpace) {
|
||||||
|
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
|
||||||
|
int min = ips.getMin();
|
||||||
|
int max = ips.getMax();
|
||||||
|
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
||||||
|
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
||||||
|
} else if (ps instanceof FixedValue){
|
||||||
|
numValuesPerParam[i] = 1;
|
||||||
|
} else {
|
||||||
|
numValuesPerParam[i] = discretizationCount;
|
||||||
|
}
|
||||||
|
searchSize *= numValuesPerParam[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (searchSize >= Integer.MAX_VALUE)
|
||||||
|
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
|
||||||
|
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
|
||||||
|
|
||||||
|
order = new ConcurrentLinkedQueue<>();
|
||||||
|
|
||||||
|
totalNumCandidates = (int) searchSize;
|
||||||
|
switch (mode) {
|
||||||
|
case Sequential:
|
||||||
|
for (int i = 0; i < totalNumCandidates; i++) {
|
||||||
|
order.add(i);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case RandomOrder:
|
||||||
|
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
|
||||||
|
for (int i = 0; i < totalNumCandidates; i++) {
|
||||||
|
tempList.add(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Collections.shuffle(tempList, new RandomAdaptor(rng));
|
||||||
|
order.addAll(tempList);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasMoreCandidates() {
|
||||||
|
return !order.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Candidate getCandidate() {
|
||||||
|
int next = order.remove();
|
||||||
|
|
||||||
|
//Next: max integer (candidate number) to values
|
||||||
|
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
|
||||||
|
|
||||||
|
Object value = null;
|
||||||
|
Exception e = null;
|
||||||
|
try {
|
||||||
|
value = parameterSpace.getValue(values);
|
||||||
|
} catch (Exception e2) {
|
||||||
|
log.warn("Error getting configuration for candidate", e2);
|
||||||
|
e = e2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Class<?> getCandidateType() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
|
||||||
|
//How? first map to index of num possible values. Then: to double values in range 0 to 1
|
||||||
|
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
||||||
|
//Based on: Nd4j Shape.ind2sub
|
||||||
|
|
||||||
|
int countNon1 = 0;
|
||||||
|
for( int i : numValuesPerParam)
|
||||||
|
if(i > 1)
|
||||||
|
countNon1++;
|
||||||
|
|
||||||
|
int denom = product;
|
||||||
|
int num = candidateIdx;
|
||||||
|
int[] index = new int[numValuesPerParam.length];
|
||||||
|
|
||||||
|
for (int i = index.length - 1; i >= 0; i--) {
|
||||||
|
denom /= numValuesPerParam[i];
|
||||||
|
index[i] = num / denom;
|
||||||
|
num %= denom;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Now: convert indexes to values in range [0,1]
|
||||||
|
//min value -> 0
|
||||||
|
//max value -> 1
|
||||||
|
double[] out = new double[countNon1];
|
||||||
|
int outIdx = 0;
|
||||||
|
for (int i = 0; i < numValuesPerParam.length; i++) {
|
||||||
|
if (numValuesPerParam[i] > 1){
|
||||||
|
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "GridSearchCandidateGenerator(mode=" + mode + ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RandomSearchGenerator: generates candidates at random.<br>
|
||||||
|
* Note: if a probability distribution is provided for continuous hyperparameters,
|
||||||
|
* this will be taken into account
|
||||||
|
* when generating candidates. This allows the search to be weighted more towards
|
||||||
|
* certain values according to a probability
|
||||||
|
* density. For example: generate samples for learning rate according to log uniform distribution
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||||
|
public class RandomSearchGenerator extends BaseCandidateGenerator {
|
||||||
|
|
||||||
|
@JsonCreator
|
||||||
|
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||||
|
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||||
|
@JsonProperty("initDone") boolean initDone) {
|
||||||
|
super(parameterSpace, dataParameters, initDone);
|
||||||
|
initialize();
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
|
||||||
|
this(parameterSpace, dataParameters, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
|
||||||
|
this(parameterSpace, null, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasMoreCandidates() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Candidate getCandidate() {
|
||||||
|
double[] randomValues = new double[parameterSpace.numParameters()];
|
||||||
|
for (int i = 0; i < randomValues.length; i++)
|
||||||
|
randomValues[i] = rng.nextDouble();
|
||||||
|
|
||||||
|
Object value = null;
|
||||||
|
Exception e = null;
|
||||||
|
try {
|
||||||
|
value = parameterSpace.getValue(randomValues);
|
||||||
|
} catch (Exception e2) {
|
||||||
|
log.warn("Error getting configuration for candidate", e2);
|
||||||
|
e = e2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Class<?> getCandidateType() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "RandomSearchGenerator";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Candidates are stored as Chromosome in the population model
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class Chromosome {
|
||||||
|
/**
|
||||||
|
* The fitness score of the genes.
|
||||||
|
*/
|
||||||
|
protected final double fitness;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The genes.
|
||||||
|
*/
|
||||||
|
protected final double[] genes;
|
||||||
|
|
||||||
|
public Chromosome(double[] genes, double fitness) {
|
||||||
|
this.genes = genes;
|
||||||
|
this.fitness = fitness;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class ChromosomeFactory {
|
||||||
|
private int chromosomeLength;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called by the GeneticSearchCandidateGenerator.
|
||||||
|
*/
|
||||||
|
public void initializeInstance(int chromosomeLength) {
|
||||||
|
this.chromosomeLength = chromosomeLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new instance of a Chromosome
|
||||||
|
*
|
||||||
|
* @param genes The genes
|
||||||
|
* @param fitness The fitness score
|
||||||
|
* @return A new instance of Chromosome
|
||||||
|
*/
|
||||||
|
public Chromosome createChromosome(double[] genes, double fitness) {
|
||||||
|
return new Chromosome(genes, fitness);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The number of genes in a chromosome
|
||||||
|
*/
|
||||||
|
public int getChromosomeLength() {
|
||||||
|
return chromosomeLength;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,120 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A crossover operator that linearly combines the genes of two parents. <br>
|
||||||
|
* When a crossover is generated (with a of probability <i>crossover rate</i>), each genes is a linear combination of the corresponding genes of the parents.
|
||||||
|
* <p>
|
||||||
|
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
|
||||||
|
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||||
|
|
||||||
|
private final double crossoverRate;
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
private TwoParentSelection parentSelection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The probability that the operator generates a crossover (default 0.85).
|
||||||
|
*
|
||||||
|
* @param rate A value between 0.0 and 1.0
|
||||||
|
*/
|
||||||
|
public Builder crossoverRate(double rate) {
|
||||||
|
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||||
|
|
||||||
|
this.crossoverRate = rate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The parent selection behavior. Default is random parent selection.
|
||||||
|
*
|
||||||
|
* @param parentSelection An instance of TwoParentSelection
|
||||||
|
*/
|
||||||
|
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||||
|
this.parentSelection = parentSelection;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArithmeticCrossover build() {
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parentSelection == null) {
|
||||||
|
parentSelection = new RandomTwoParentSelection();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ArithmeticCrossover(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
|
||||||
|
super(builder.parentSelection);
|
||||||
|
|
||||||
|
this.crossoverRate = builder.crossoverRate;
|
||||||
|
this.rng = builder.rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
|
||||||
|
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
|
||||||
|
* Otherwise, returns the genes of a random parent.
|
||||||
|
*
|
||||||
|
* @return The crossover result. See {@link CrossoverResult}.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CrossoverResult crossover() {
|
||||||
|
double[][] parents = parentSelection.selectParents();
|
||||||
|
|
||||||
|
double[] offspringValues = new double[parents[0].length];
|
||||||
|
|
||||||
|
if (rng.nextDouble() < crossoverRate) {
|
||||||
|
for (int i = 0; i < offspringValues.length; ++i) {
|
||||||
|
double t = rng.nextDouble();
|
||||||
|
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
|
||||||
|
}
|
||||||
|
return new CrossoverResult(true, offspringValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CrossoverResult(false, parents[0]);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract class for all crossover operators
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class CrossoverOperator {
|
||||||
|
protected PopulationModel populationModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will be called by the selection operator once the population model is instantiated.
|
||||||
|
*/
|
||||||
|
public void initializeInstance(PopulationModel populationModel) {
|
||||||
|
this.populationModel = populationModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs the crossover
|
||||||
|
*
|
||||||
|
* @return The crossover result. See {@link CrossoverResult}.
|
||||||
|
*/
|
||||||
|
public abstract CrossoverResult crossover();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returned by a crossover operator
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class CrossoverResult {
|
||||||
|
/**
|
||||||
|
* If false, there was no crossover and the operator simply returned the genes of a random parent.
|
||||||
|
* If true, the genes are the result of a crossover.
|
||||||
|
*/
|
||||||
|
private final boolean isModified;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The genes returned by the operator.
|
||||||
|
*/
|
||||||
|
private final double[] genes;
|
||||||
|
|
||||||
|
public CrossoverResult(boolean isModified, double[] genes) {
|
||||||
|
this.isModified = isModified;
|
||||||
|
this.genes = genes;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
import java.util.Deque;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The K-Point crossover will select at random multiple crossover points.<br>
|
||||||
|
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
|
||||||
|
*/
|
||||||
|
public class KPointCrossover extends TwoParentsCrossoverOperator {
|
||||||
|
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||||
|
private static final int DEFAULT_MIN_CROSSOVER = 1;
|
||||||
|
private static final int DEFAULT_MAX_CROSSOVER = 4;
|
||||||
|
|
||||||
|
private final double crossoverRate;
|
||||||
|
private final int minCrossovers;
|
||||||
|
private final int maxCrossovers;
|
||||||
|
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||||
|
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
|
||||||
|
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
private TwoParentSelection parentSelection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The probability that the operator generates a crossover (default 0.85).
|
||||||
|
*
|
||||||
|
* @param rate A value between 0.0 and 1.0
|
||||||
|
*/
|
||||||
|
public Builder crossoverRate(double rate) {
|
||||||
|
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||||
|
|
||||||
|
this.crossoverRate = rate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of crossovers points (default is min 1, max 4)
|
||||||
|
*
|
||||||
|
* @param min The minimum number
|
||||||
|
* @param max The maximum number
|
||||||
|
*/
|
||||||
|
public Builder numCrossovers(int min, int max) {
|
||||||
|
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
|
||||||
|
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
|
||||||
|
|
||||||
|
this.minCrossovers = min;
|
||||||
|
this.maxCrossovers = max;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a fixed number of crossover points
|
||||||
|
*
|
||||||
|
* @param num The number of crossovers
|
||||||
|
*/
|
||||||
|
public Builder numCrossovers(int num) {
|
||||||
|
Preconditions.checkState(num >= 0, "Num must be positive");
|
||||||
|
|
||||||
|
this.minCrossovers = num;
|
||||||
|
this.maxCrossovers = num;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The parent selection behavior. Default is random parent selection.
|
||||||
|
*
|
||||||
|
* @param parentSelection An instance of TwoParentSelection
|
||||||
|
*/
|
||||||
|
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||||
|
this.parentSelection = parentSelection;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public KPointCrossover build() {
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parentSelection == null) {
|
||||||
|
parentSelection = new RandomTwoParentSelection();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new KPointCrossover(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private CrossoverPointsGenerator crossoverPointsGenerator;
|
||||||
|
|
||||||
|
/**
 * Copies the configuration held by the builder into the new operator.
 *
 * @param builder The builder carrying the parent selection, rate, crossover bounds and generator
 */
private KPointCrossover(KPointCrossover.Builder builder) {
    super(builder.parentSelection);

    rng = builder.rng;
    crossoverRate = builder.crossoverRate;
    minCrossovers = builder.minCrossovers;
    maxCrossovers = builder.maxCrossovers;
}
|
||||||
|
|
||||||
|
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
|
||||||
|
if (crossoverPointsGenerator == null) {
|
||||||
|
crossoverPointsGenerator =
|
||||||
|
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
return crossoverPointsGenerator;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
|
||||||
|
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
|
||||||
|
* Otherwise, returns the genes of a random parent.
|
||||||
|
*
|
||||||
|
* @return The crossover result. See {@link CrossoverResult}.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CrossoverResult crossover() {
|
||||||
|
double[][] parents = parentSelection.selectParents();
|
||||||
|
|
||||||
|
boolean isModified = false;
|
||||||
|
double[] resultGenes = parents[0];
|
||||||
|
|
||||||
|
if (rng.nextDouble() < crossoverRate) {
|
||||||
|
// Select crossover points
|
||||||
|
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
|
||||||
|
|
||||||
|
// Crossover
|
||||||
|
resultGenes = new double[parents[0].length];
|
||||||
|
int currentParent = 0;
|
||||||
|
int nextCrossover = crossoverPoints.pop();
|
||||||
|
for (int i = 0; i < resultGenes.length; ++i) {
|
||||||
|
if (i == nextCrossover) {
|
||||||
|
currentParent = currentParent == 0 ? 1 : 0;
|
||||||
|
nextCrossover = crossoverPoints.pop();
|
||||||
|
}
|
||||||
|
resultGenes[i] = parents[currentParent][i];
|
||||||
|
}
|
||||||
|
isModified = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CrossoverResult(isModified, resultGenes);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,123 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The single point crossover will select a random point where every genes before that point comes from one parent
|
||||||
|
* and after which every genes comes from the other parent.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
|
||||||
|
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||||
|
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
private final double crossoverRate;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
private TwoParentSelection parentSelection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The probability that the operator generates a crossover (default 0.85).
|
||||||
|
*
|
||||||
|
* @param rate A value between 0.0 and 1.0
|
||||||
|
*/
|
||||||
|
public Builder crossoverRate(double rate) {
|
||||||
|
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||||
|
|
||||||
|
this.crossoverRate = rate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The parent selection behavior. Default is random parent selection.
|
||||||
|
*
|
||||||
|
* @param parentSelection An instance of TwoParentSelection
|
||||||
|
*/
|
||||||
|
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||||
|
this.parentSelection = parentSelection;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SinglePointCrossover build() {
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parentSelection == null) {
|
||||||
|
parentSelection = new RandomTwoParentSelection();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new SinglePointCrossover(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
|
||||||
|
super(builder.parentSelection);
|
||||||
|
|
||||||
|
this.crossoverRate = builder.crossoverRate;
|
||||||
|
this.rng = builder.rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
|
||||||
|
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
|
||||||
|
* Otherwise, returns the genes of a random parent.
|
||||||
|
*
|
||||||
|
* @return The crossover result. See {@link CrossoverResult}.
|
||||||
|
*/
|
||||||
|
public CrossoverResult crossover() {
|
||||||
|
double[][] parents = parentSelection.selectParents();
|
||||||
|
|
||||||
|
boolean isModified = false;
|
||||||
|
double[] resultGenes = parents[0];
|
||||||
|
|
||||||
|
if (rng.nextDouble() < crossoverRate) {
|
||||||
|
int chromosomeLength = parents[0].length;
|
||||||
|
|
||||||
|
// Crossover
|
||||||
|
resultGenes = new double[chromosomeLength];
|
||||||
|
|
||||||
|
int crossoverPoint = rng.nextInt(chromosomeLength);
|
||||||
|
for (int i = 0; i < resultGenes.length; ++i) {
|
||||||
|
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
|
||||||
|
}
|
||||||
|
isModified = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CrossoverResult(isModified, resultGenes);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract class for all crossover operators that applies to two parents.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
|
||||||
|
|
||||||
|
protected final TwoParentSelection parentSelection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param parentSelection A parent selection that selects two parents.
|
||||||
|
*/
|
||||||
|
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
|
||||||
|
this.parentSelection = parentSelection;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will be called by the selection operator once the population model is instantiated.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void initializeInstance(PopulationModel populationModel) {
|
||||||
|
super.initializeInstance(populationModel);
|
||||||
|
parentSelection.initializeInstance(populationModel.getPopulation());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class UniformCrossover extends TwoParentsCrossoverOperator {
|
||||||
|
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||||
|
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
|
||||||
|
|
||||||
|
private final double crossoverRate;
|
||||||
|
private final double parentBiasFactor;
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||||
|
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
private TwoParentSelection parentSelection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The probability that the operator generates a crossover (default 0.85).
|
||||||
|
*
|
||||||
|
* @param rate A value between 0.0 and 1.0
|
||||||
|
*/
|
||||||
|
public Builder crossoverRate(double rate) {
|
||||||
|
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||||
|
|
||||||
|
this.crossoverRate = rate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A factor that will introduce a bias in the parent selection.<br>
|
||||||
|
*
|
||||||
|
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
|
||||||
|
*/
|
||||||
|
public Builder parentBiasFactor(double factor) {
|
||||||
|
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
|
||||||
|
factor);
|
||||||
|
|
||||||
|
this.parentBiasFactor = factor;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The parent selection behavior. Default is random parent selection.
|
||||||
|
*
|
||||||
|
* @param parentSelection An instance of TwoParentSelection
|
||||||
|
*/
|
||||||
|
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||||
|
this.parentSelection = parentSelection;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public UniformCrossover build() {
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
if (parentSelection == null) {
|
||||||
|
parentSelection = new RandomTwoParentSelection();
|
||||||
|
}
|
||||||
|
return new UniformCrossover(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private UniformCrossover(UniformCrossover.Builder builder) {
|
||||||
|
super(builder.parentSelection);
|
||||||
|
|
||||||
|
this.crossoverRate = builder.crossoverRate;
|
||||||
|
this.parentBiasFactor = builder.parentBiasFactor;
|
||||||
|
this.rng = builder.rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
|
||||||
|
* One of the parent may be favored if the bias is different than 0.5
|
||||||
|
* Otherwise, returns the genes of a random parent.
|
||||||
|
*
|
||||||
|
* @return The crossover result. See {@link CrossoverResult}.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CrossoverResult crossover() {
|
||||||
|
// select the parents
|
||||||
|
double[][] parents = parentSelection.selectParents();
|
||||||
|
|
||||||
|
double[] resultGenes = parents[0];
|
||||||
|
boolean isModified = false;
|
||||||
|
|
||||||
|
if (rng.nextDouble() < crossoverRate) {
|
||||||
|
// Crossover
|
||||||
|
resultGenes = new double[parents[0].length];
|
||||||
|
|
||||||
|
for (int i = 0; i < resultGenes.length; ++i) {
|
||||||
|
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
|
||||||
|
}
|
||||||
|
isModified = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CrossoverResult(isModified, resultGenes);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract class for all parent selection behaviors
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class ParentSelection {
|
||||||
|
protected List<Chromosome> population;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will be called by the crossover operator once the population model is instantiated.
|
||||||
|
*/
|
||||||
|
public void initializeInstance(List<Chromosome> population) {
|
||||||
|
this.population = population;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs the parent selection
|
||||||
|
*
|
||||||
|
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||||
|
*/
|
||||||
|
public abstract double[][] selectParents();
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A parent selection behavior that returns two random parents.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class RandomTwoParentSelection extends TwoParentSelection {
|
||||||
|
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
|
||||||
|
public RandomTwoParentSelection() {
|
||||||
|
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public RandomTwoParentSelection(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Selects two random parents
|
||||||
|
*
|
||||||
|
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public double[][] selectParents() {
|
||||||
|
double[][] parents = new double[2][];
|
||||||
|
|
||||||
|
int parent1Idx = rng.nextInt(population.size());
|
||||||
|
int parent2Idx;
|
||||||
|
do {
|
||||||
|
parent2Idx = rng.nextInt(population.size());
|
||||||
|
} while (parent1Idx == parent2Idx);
|
||||||
|
|
||||||
|
parents[0] = population.get(parent1Idx).getGenes();
|
||||||
|
parents[1] = population.get(parent2Idx).getGenes();
|
||||||
|
|
||||||
|
return parents;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Abstract class for all parent selection behaviors that selects two parents.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class TwoParentSelection extends ParentSelection {
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A helper class used by {@link KPointCrossover} to generate the crossover points
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class CrossoverPointsGenerator {
|
||||||
|
private final int minCrossovers;
|
||||||
|
private final int maxCrossovers;
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
private List<Integer> parameterIndexes;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor
|
||||||
|
*
|
||||||
|
* @param chromosomeLength The number of genes
|
||||||
|
* @param minCrossovers The minimum number of crossover points to generate
|
||||||
|
* @param maxCrossovers The maximum number of crossover points to generate
|
||||||
|
* @param rng A RandomGenerator instance
|
||||||
|
*/
|
||||||
|
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
|
||||||
|
this.minCrossovers = minCrossovers;
|
||||||
|
this.maxCrossovers = maxCrossovers;
|
||||||
|
this.rng = rng;
|
||||||
|
parameterIndexes = new ArrayList<Integer>();
|
||||||
|
for (int i = 0; i < chromosomeLength; ++i) {
|
||||||
|
parameterIndexes.add(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a list of crossover points.
|
||||||
|
*
|
||||||
|
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
|
||||||
|
*/
|
||||||
|
public Deque<Integer> getCrossoverPoints() {
|
||||||
|
Collections.shuffle(parameterIndexes);
|
||||||
|
List<Integer> crossoverPointLists =
|
||||||
|
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
|
||||||
|
Collections.sort(crossoverPointLists);
|
||||||
|
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
|
||||||
|
crossoverPoints.add(Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
return crossoverPoints;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The cull operator will remove from the population the least desirables chromosomes.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface CullOperator {
|
||||||
|
/**
|
||||||
|
* Will be called by the population model once created.
|
||||||
|
*/
|
||||||
|
void initializeInstance(PopulationModel populationModel);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cull the population to the culled size.
|
||||||
|
*/
|
||||||
|
void cullPopulation();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The target population size after culling.
|
||||||
|
*/
|
||||||
|
int getCulledSize();
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class LeastFitCullOperator extends RatioCullOperator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The default cull ratio is 1/3.
|
||||||
|
*/
|
||||||
|
public LeastFitCullOperator() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||||
|
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||||
|
*/
|
||||||
|
public LeastFitCullOperator(double cullRatio) {
|
||||||
|
super(cullRatio);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void cullPopulation() {
|
||||||
|
while (population.size() > culledSize) {
|
||||||
|
population.remove(population.size() - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class RatioCullOperator implements CullOperator {
|
||||||
|
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
|
||||||
|
protected int culledSize;
|
||||||
|
protected List<Chromosome> population;
|
||||||
|
protected final double cullRatio;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||||
|
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||||
|
*/
|
||||||
|
public RatioCullOperator(double cullRatio) {
|
||||||
|
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
|
||||||
|
cullRatio);
|
||||||
|
|
||||||
|
this.cullRatio = cullRatio;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The default cull ratio is 1/3
|
||||||
|
*/
|
||||||
|
public RatioCullOperator() {
|
||||||
|
this(DEFAULT_CULL_RATIO);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will be called by the population model once created.
|
||||||
|
*/
|
||||||
|
public void initializeInstance(PopulationModel populationModel) {
|
||||||
|
this.population = populationModel.getPopulation();
|
||||||
|
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The target population size after culling.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int getCulledSize() {
|
||||||
|
return culledSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
|
||||||
|
|
||||||
|
/**
 * Unchecked exception signalling that the genetic generation could not produce a usable set of genes.
 */
public class GeneticGenerationException extends RuntimeException {

    /**
     * @param message Description of the generation failure
     */
    public GeneticGenerationException(String message) {
        super(message);
    }

    /**
     * Variant that keeps the underlying failure attached, so the original stack trace is not lost
     * when a lower-level exception is translated into a generation failure.
     *
     * @param message Description of the generation failure
     * @param cause   The exception that triggered this failure (may be null)
     */
    public GeneticGenerationException(String message, Throwable cause) {
        super(message, cause);
    }
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||||
|
|
||||||
|
/**
 * Contract for operators that apply a mutation to a set of genes.
 *
 * @author Alexandre Boulanger
 */
public interface MutationOperator {

    /**
     * Applies a mutation to the supplied genes.
     *
     * @param genes The genes to be mutated
     * @return True if the genes were mutated, otherwise false.
     */
    boolean mutate(double[] genes);
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class RandomMutationOperator implements MutationOperator {
|
||||||
|
private static final double DEFAULT_MUTATION_RATE = 0.005;
|
||||||
|
|
||||||
|
private final double mutationRate;
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private double mutationRate = DEFAULT_MUTATION_RATE;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each gene will have this probability of being mutated.
|
||||||
|
*
|
||||||
|
* @param rate The mutation rate. (default 0.005)
|
||||||
|
*/
|
||||||
|
public Builder mutationRate(double rate) {
|
||||||
|
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||||
|
|
||||||
|
this.mutationRate = rate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomMutationOperator build() {
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
return new RandomMutationOperator(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
|
||||||
|
this.mutationRate = builder.mutationRate;
|
||||||
|
this.rng = builder.rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
|
||||||
|
*
|
||||||
|
* @param genes The genes to be mutated
|
||||||
|
* @return True if the genes were mutated, otherwise false.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public boolean mutate(double[] genes) {
|
||||||
|
boolean hasMutation = false;
|
||||||
|
|
||||||
|
for (int i = 0; i < genes.length; ++i) {
|
||||||
|
if (rng.nextDouble() < mutationRate) {
|
||||||
|
genes[i] = rng.nextDouble();
|
||||||
|
hasMutation = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hasMutation;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A population initializer that build an empty population.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class EmptyPopulationInitializer implements PopulationInitializer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize an empty population
|
||||||
|
*
|
||||||
|
* @param size The maximum size of the population.
|
||||||
|
* @return The initialized population.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<Chromosome> getInitializedPopulation(int size) {
|
||||||
|
return new ArrayList<>(size);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An initializer that construct the population used by the population model.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface PopulationInitializer {
|
||||||
|
/**
|
||||||
|
* Called by the population model to construct the population
|
||||||
|
*
|
||||||
|
* @param size The maximum size of the population
|
||||||
|
* @return An initialized population
|
||||||
|
*/
|
||||||
|
List<Chromosome> getInitializedPopulation(int size);
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A listener that is called when the population changes.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface PopulationListener {
|
||||||
|
/**
|
||||||
|
* Called after the population has changed.
|
||||||
|
*
|
||||||
|
* @param population The population after it has changed.
|
||||||
|
*/
|
||||||
|
void onChanged(List<Chromosome> population);
|
||||||
|
}
|
|
@ -0,0 +1,182 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The population model handles all aspects of the population (initialization, additions and culling)
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class PopulationModel {
|
||||||
|
private static final int DEFAULT_POPULATION_SIZE = 30;
|
||||||
|
|
||||||
|
private final CullOperator cullOperator;
|
||||||
|
private final List<PopulationListener> populationListeners = new ArrayList<>();
|
||||||
|
private Comparator<Chromosome> chromosomeComparator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum population size
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private final int populationSize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The population
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public final List<Chromosome> population;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A comparator used when higher fitness value is better
|
||||||
|
*/
|
||||||
|
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
|
||||||
|
@Override
|
||||||
|
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||||
|
return -Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A comparator used when lower fitness value is better
|
||||||
|
*/
|
||||||
|
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
|
||||||
|
@Override
|
||||||
|
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||||
|
return Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private int populationSize = DEFAULT_POPULATION_SIZE;
|
||||||
|
private PopulationInitializer populationInitializer;
|
||||||
|
private CullOperator cullOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use an alternate population initialization behavior. Default is empty population.
|
||||||
|
*
|
||||||
|
* @param populationInitializer An instance of PopulationInitializer
|
||||||
|
*/
|
||||||
|
public Builder populationInitializer(PopulationInitializer populationInitializer) {
|
||||||
|
this.populationInitializer = populationInitializer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum population size. <br>
|
||||||
|
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
|
||||||
|
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
|
||||||
|
*
|
||||||
|
* @param size The maximum size of the population
|
||||||
|
*/
|
||||||
|
public Builder populationSize(int size) {
|
||||||
|
populationSize = size;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use an alternate cull operator behavior. Default is least fit culling.
|
||||||
|
*
|
||||||
|
* @param cullOperator An instance of a CullOperator
|
||||||
|
*/
|
||||||
|
public Builder cullOperator(CullOperator cullOperator) {
|
||||||
|
this.cullOperator = cullOperator;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PopulationModel build() {
|
||||||
|
if (cullOperator == null) {
|
||||||
|
cullOperator = new LeastFitCullOperator();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (populationInitializer == null) {
|
||||||
|
populationInitializer = new EmptyPopulationInitializer();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new PopulationModel(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public PopulationModel(PopulationModel.Builder builder) {
|
||||||
|
populationSize = builder.populationSize;
|
||||||
|
population = new ArrayList<>(builder.populationSize);
|
||||||
|
PopulationInitializer populationInitializer = builder.populationInitializer;
|
||||||
|
|
||||||
|
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
|
||||||
|
population.clear();
|
||||||
|
population.addAll(initializedPopulation);
|
||||||
|
|
||||||
|
cullOperator = builder.cullOperator;
|
||||||
|
cullOperator.initializeInstance(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called by the GeneticSearchCandidateGenerator
|
||||||
|
*/
|
||||||
|
public void initializeInstance(boolean minimizeScore) {
|
||||||
|
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a PopulationListener to the list of change listeners
|
||||||
|
* @param listener A PopulationListener instance
|
||||||
|
*/
|
||||||
|
public void addListener(PopulationListener listener) {
|
||||||
|
populationListeners.add(listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
|
||||||
|
*
|
||||||
|
* @param element The chromosome to be added
|
||||||
|
*/
|
||||||
|
public void add(Chromosome element) {
|
||||||
|
if (population.size() == populationSize) {
|
||||||
|
cullOperator.cullPopulation();
|
||||||
|
}
|
||||||
|
|
||||||
|
population.add(element);
|
||||||
|
|
||||||
|
Collections.sort(population, chromosomeComparator);
|
||||||
|
|
||||||
|
triggerPopulationChangedListeners(population);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Return false when the population is below the culled size, otherwise true. <br>
|
||||||
|
* Used by the selection operator to know if the population is still too small and should generate random genes.
|
||||||
|
*/
|
||||||
|
public boolean isReadyToBreed() {
|
||||||
|
return population.size() >= cullOperator.getCulledSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void triggerPopulationChangedListeners(List<Chromosome> population) {
|
||||||
|
for (PopulationListener listener : populationListeners) {
|
||||||
|
listener.onChanged(population);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,197 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
|
||||||
|
* will start to generate offsprings of parents selected in the population.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class GeneticSelectionOperator extends SelectionOperator {
|
||||||
|
|
||||||
|
private final static int PREVIOUS_GENES_TO_KEEP = 100;
|
||||||
|
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
|
||||||
|
|
||||||
|
private final CrossoverOperator crossoverOperator;
|
||||||
|
private final MutationOperator mutationOperator;
|
||||||
|
private final RandomGenerator rng;
|
||||||
|
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
|
||||||
|
private int previousGenesIdx = 0;
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private ChromosomeFactory chromosomeFactory;
|
||||||
|
private PopulationModel populationModel;
|
||||||
|
private CrossoverOperator crossoverOperator;
|
||||||
|
private MutationOperator mutationOperator;
|
||||||
|
private RandomGenerator rng;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use an alternate crossover behavior. Default is SinglePointCrossover.
|
||||||
|
*
|
||||||
|
* @param crossoverOperator An instance of CrossoverOperator
|
||||||
|
*/
|
||||||
|
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
|
||||||
|
this.crossoverOperator = crossoverOperator;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use an alternate mutation behavior. Default is RandomMutationOperator.
|
||||||
|
*
|
||||||
|
* @param mutationOperator An instance of MutationOperator
|
||||||
|
*/
|
||||||
|
public Builder mutationOperator(MutationOperator mutationOperator) {
|
||||||
|
this.mutationOperator = mutationOperator;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a supplied RandomGenerator
|
||||||
|
*
|
||||||
|
* @param rng An instance of RandomGenerator
|
||||||
|
*/
|
||||||
|
public Builder randomGenerator(RandomGenerator rng) {
|
||||||
|
this.rng = rng;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GeneticSelectionOperator build() {
|
||||||
|
if (crossoverOperator == null) {
|
||||||
|
crossoverOperator = new SinglePointCrossover.Builder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mutationOperator == null) {
|
||||||
|
mutationOperator = new RandomMutationOperator.Builder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rng == null) {
|
||||||
|
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||||
|
}
|
||||||
|
|
||||||
|
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
|
||||||
|
RandomGenerator rng) {
|
||||||
|
this.crossoverOperator = crossoverOperator;
|
||||||
|
this.mutationOperator = mutationOperator;
|
||||||
|
this.rng = rng;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called by GeneticSearchCandidateGenerator
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||||
|
super.initializeInstance(populationModel, chromosomeFactory);
|
||||||
|
crossoverOperator.initializeInstance(populationModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a new set of genes. Has two distinct modes of operation
|
||||||
|
* <ul>
|
||||||
|
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
|
||||||
|
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
|
||||||
|
* </ul>
|
||||||
|
* @return Returns the generated set of genes
|
||||||
|
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
|
||||||
|
* or if the crossover and the mutation operators can't generate a set,
|
||||||
|
* this exception is thrown.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public double[] buildNextGenes() {
|
||||||
|
double[] result;
|
||||||
|
|
||||||
|
boolean hasAlreadyBeenTried;
|
||||||
|
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||||
|
do {
|
||||||
|
if (populationModel.isReadyToBreed()) {
|
||||||
|
result = buildOffspring();
|
||||||
|
} else {
|
||||||
|
result = buildRandomGenes();
|
||||||
|
}
|
||||||
|
|
||||||
|
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
|
||||||
|
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
|
||||||
|
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
|
||||||
|
}
|
||||||
|
} while (hasAlreadyBeenTried);
|
||||||
|
|
||||||
|
previousGenes[previousGenesIdx] = result;
|
||||||
|
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean hasAlreadyBeenTried(double[] genes) {
|
||||||
|
for (int i = 0; i < previousGenes.length; ++i) {
|
||||||
|
double[] current = previousGenes[i];
|
||||||
|
if (current != null && Arrays.equals(current, genes)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] buildOffspring() {
|
||||||
|
double[] offspringValues;
|
||||||
|
|
||||||
|
boolean isModified;
|
||||||
|
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||||
|
do {
|
||||||
|
CrossoverResult crossoverResult = crossoverOperator.crossover();
|
||||||
|
offspringValues = crossoverResult.getGenes();
|
||||||
|
isModified = crossoverResult.isModified();
|
||||||
|
isModified |= mutationOperator.mutate(offspringValues);
|
||||||
|
|
||||||
|
if (!isModified && --attemptsRemaining == 0) {
|
||||||
|
throw new GeneticGenerationException(
|
||||||
|
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
|
||||||
|
MAX_NUM_GENERATION_ATTEMPTS));
|
||||||
|
}
|
||||||
|
} while (!isModified);
|
||||||
|
|
||||||
|
return offspringValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] buildRandomGenes() {
|
||||||
|
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
|
||||||
|
for (int i = 0; i < randomValues.length; ++i) {
|
||||||
|
randomValues[i] = rng.nextDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
return randomValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||||
|
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class SelectionOperator {
|
||||||
|
protected PopulationModel populationModel;
|
||||||
|
protected ChromosomeFactory chromosomeFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called by GeneticSearchCandidateGenerator
|
||||||
|
*/
|
||||||
|
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||||
|
|
||||||
|
this.populationModel = populationModel;
|
||||||
|
this.chromosomeFactory = chromosomeFactory;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new set of genes.
|
||||||
|
*/
|
||||||
|
public abstract double[] buildNextGenes();
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.arbiter.optimize.generator.util;
|
||||||
|
|
||||||
|
import org.nd4j.common.function.Supplier;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
|
||||||
|
/**
 * A serializable supplier that captures an object as its Java-serialized bytes at construction time
 * and deserializes a new instance from those bytes on every {@link #get()} call.
 *
 * @param <T> Type of the supplied object; instances must be serializable
 */
public class SerializedSupplier<T> implements Serializable, Supplier<T> {

    // Java-serialized form of the object passed to the constructor.
    private byte[] asBytes;

    /**
     * @param obj The object to capture; it is serialized immediately, so later changes to it are not reflected
     * @throws RuntimeException If the object cannot be serialized
     */
    public SerializedSupplier(T obj) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                        ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            // Push any buffered data down to baos before taking the snapshot; the streams themselves
            // are closed by try-with-resources.
            oos.flush();
            asBytes = baos.toByteArray();
        } catch (Exception e) {
            throw new RuntimeException("Error serializing object - must be serializable", e);
        }
    }

    /**
     * @return A new instance deserialized from the captured bytes
     * @throws RuntimeException If deserialization fails
     */
    @Override
    @SuppressWarnings("unchecked") // the bytes were written from a T in the constructor
    public T get() {
        // NOTE(review): this is native Java deserialization. Because this class is itself Serializable,
        // asBytes may originate from an external stream - only use with trusted data.
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))) {
            return (T) ois.readObject();
        } catch (Exception e) {
            throw new RuntimeException("Error deserializing object", e);
        }
    }
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue