Compare commits
No commits in common. "1c3496ad844392afb1411b78234e9720a023d972" and "1749b9e4afa92f11f03889e73f671237e4dc771b" have entirely different histories.
1c3496ad84
...
1749b9e4af
|
@ -1,15 +1,8 @@
|
|||
FROM nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04
|
||||
FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04
|
||||
|
||||
RUN apt-get update && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git
|
||||
#Build cmake version from source \
|
||||
#RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
|
||||
# tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
|
||||
# ./bootstrap && make && make install
|
||||
RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2-linux-x86_64.sh && \
|
||||
mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \
|
||||
rm cmake-3.24.2-linux-x86_64.sh
|
||||
|
||||
|
||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
||||
RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \
|
||||
tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \
|
||||
./bootstrap && make && make install
|
||||
|
||||
|
|
|
@ -36,8 +36,6 @@ pom.xml.versionsBackup
|
|||
pom.xml.next
|
||||
release.properties
|
||||
*dependency-reduced-pom.xml
|
||||
**/build/*
|
||||
.gradle/*
|
||||
|
||||
# Specific for Nd4j
|
||||
*.md5
|
||||
|
@ -52,12 +50,12 @@ release.properties
|
|||
*.dylib
|
||||
.vs/
|
||||
.vscode/
|
||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
|
||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
|
||||
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
|
||||
.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
|
||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
|
||||
.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
|
||||
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
|
||||
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
|
||||
nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
|
||||
nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
|
||||
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
|
||||
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
|
||||
|
||||
doc_sources/
|
||||
doc_sources_*
|
||||
|
@ -69,8 +67,8 @@ venv/
|
|||
venv2/
|
||||
|
||||
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
|
||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
|
||||
.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
|
||||
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
|
||||
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
|
||||
|
||||
# Ignore meld temp files
|
||||
*.orig
|
||||
|
@ -84,15 +82,3 @@ bruai4j-native-common/cmake*
|
|||
*.dll
|
||||
/bruai4j-native/bruai4j-native-common/blasbuild/
|
||||
/bruai4j-native/bruai4j-native-common/build/
|
||||
/cavis-native/cavis-native-lib/blasbuild/
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/classes/org.deeplearning4j.gradientcheck.AttentionLayerTest.html
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/base-style.css
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/css/style.css
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/js/report.js
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/packages/org.deeplearning4j.gradientcheck.html
|
||||
/cavis-dnn/cavis-dnn-core/build/reports/tests/cudaTest/index.html
|
||||
/cavis-dnn/cavis-dnn-core/build/resources/main/iris.dat
|
||||
/cavis-dnn/cavis-dnn-core/build/resources/test/junit-platform.properties
|
||||
/cavis-dnn/cavis-dnn-core/build/resources/test/logback-test.xml
|
||||
/cavis-dnn/cavis-dnn-core/build/test-results/cudaTest/TEST-org.deeplearning4j.gradientcheck.AttentionLayerTest.xml
|
||||
/cavis-dnn/cavis-dnn-core/build/tmp/jar/MANIFEST.MF
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
label 'linux'
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('prep-build-environment-linux-cpu') {
|
||||
steps {
|
||||
checkout scm
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
sh 'sh ./gradlew --version'
|
||||
}
|
||||
}
|
||||
stage('build-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
/*stage('test-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}*/
|
||||
stage('publish-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
/* dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'linux && cuda'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all' --needed for test only, you can build without GPU
|
||||
}
|
||||
*/
|
||||
label 'linux && cuda'
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('prep-build-environment-linux-cuda') {
|
||||
steps {
|
||||
checkout scm
|
||||
//sh 'nvidia-smi'
|
||||
sh 'nvcc --version'
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
sh 'sh ./gradlew --version'
|
||||
}
|
||||
}
|
||||
stage('build-linux-cuda') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew build --stacktrace -PCAVIS_CHIP=cuda \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'linux && docker && cuda'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all' --needed for test only, you can build without GPU
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage("Build all chip") {
|
||||
parallel {
|
||||
|
||||
stage('prep-build-environment-linux-cuda') {
|
||||
|
||||
steps {
|
||||
checkout scm
|
||||
//sh 'nvidia-smi'
|
||||
sh 'nvcc --version'
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
sh 'sh ./gradlew --version'
|
||||
}
|
||||
}
|
||||
stage('build-linux-cuda') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'linux && docker'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all'
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('prep-build-environment-linux-cpu') {
|
||||
steps {
|
||||
checkout scm
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
sh 'sh ./gradlew --version'
|
||||
}
|
||||
}
|
||||
stage('build-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
stage('test-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
//sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \
|
||||
// -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
// -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
stage('publish-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'linux && docker'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all'
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('publish-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew publish -x test -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,13 +20,14 @@
|
|||
*/
|
||||
|
||||
pipeline {
|
||||
|
||||
agent {
|
||||
dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'linux && docker && cuda'
|
||||
label 'linuxdocker'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all' --needed for test only, you can build without GPU
|
||||
args '--gpus all'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,7 +35,7 @@ pipeline {
|
|||
stage('prep-build-environment-linux-cuda') {
|
||||
steps {
|
||||
checkout scm
|
||||
//sh 'nvidia-smi'
|
||||
sh 'nvidia-smi'
|
||||
sh 'nvcc --version'
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
|
@ -43,33 +44,19 @@ pipeline {
|
|||
}
|
||||
stage('build-linux-cuda') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
MAVEN = credentials('Internal Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \
|
||||
sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cuda \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
stage('test-linux-cuda') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -Pskip-native=true -PCAVIS_CHIP=cuda \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
dockerfile {
|
||||
filename 'Dockerfile'
|
||||
dir '.docker'
|
||||
label 'WSL-docker'
|
||||
//additionalBuildArgs '--build-arg version=1.0.2'
|
||||
//args '--gpus all'
|
||||
}
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('prep-build-environment-linux-cpu') {
|
||||
steps {
|
||||
checkout scm
|
||||
sh 'gcc --version'
|
||||
sh 'cmake --version'
|
||||
sh 'sh ./gradlew --version'
|
||||
}
|
||||
}
|
||||
stage('build-linux-cpu') {
|
||||
environment {
|
||||
MAVEN = credentials('Internal_Archiva')
|
||||
OSSRH = credentials('OSSRH')
|
||||
}
|
||||
|
||||
steps {
|
||||
withGradle {
|
||||
sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cpu \
|
||||
-Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
|
||||
-PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
|
||||
}
|
||||
//stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -48,12 +48,12 @@ Deeplearning4J offers a very high level API for defining even complex neural net
|
|||
you how LeNet, a convolutional neural network, is defined in DL4J.
|
||||
|
||||
```java
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.seed(seed)
|
||||
.l2(0.0005)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.updater(new Adam(1e-3))
|
||||
|
||||
.list()
|
||||
.layer(new ConvolutionLayer.Builder(5, 5)
|
||||
.stride(1,1)
|
||||
.nOut(20)
|
||||
|
@ -78,7 +78,7 @@ NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
|||
.nOut(outputNum)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.build())
|
||||
.inputType(InputType.convolutionalFlat(28,28,1))
|
||||
.setInputType(InputType.convolutionalFlat(28,28,1))
|
||||
.build();
|
||||
|
||||
```
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
branches:
|
||||
only:
|
||||
- master
|
||||
notifications:
|
||||
email: false
|
||||
dist: trusty
|
||||
sudo: false
|
||||
cache:
|
||||
directories:
|
||||
- $HOME/.m2
|
||||
language: java
|
||||
jdk:
|
||||
- openjdk8
|
||||
matrix:
|
||||
include:
|
||||
- os: linux
|
||||
env: OS=linux-x86_64 SCALA=2.10
|
||||
install: true
|
||||
script: bash ./ci/build-linux-x86_64.sh
|
||||
- os: linux
|
||||
env: OS=linux-x86_64 SCALA=2.11
|
||||
install: true
|
||||
script: bash ./ci/build-linux-x86_64.sh
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# Arbiter
|
||||
|
||||
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
|
||||
|
||||
|
||||
## Modules
|
||||
Arbiter contains the following modules:
|
||||
|
||||
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
|
||||
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
|
||||
|
||||
|
||||
## Hyperparameter Optimization Functionality
|
||||
|
||||
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
|
||||
|
||||
- Grid search
|
||||
- Random search
|
||||
|
||||
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes (not part of the open-source version of Arbiter) may perform better still.
|
||||
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
|
||||
|
||||
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
|
||||
|
||||
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
|
||||
|
||||
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
|
||||
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate model configurations of some type ```C```. The following implementations are defined in arbiter-core:
|
||||
- ```RandomSearchCandidateGenerator```
|
||||
- ```GridSearchCandidateGenerator```
|
||||
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
|
||||
- A key concept here is that the score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
|
||||
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
|
||||
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
|
||||
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
|
||||
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
|
||||
- Note that the ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
|
||||
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
|
||||
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
|
||||
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
|
||||
|
||||
|
||||
### Optimization of DeepLearning4J Models
|
||||
|
||||
(This section: forthcoming)
|
|
@ -0,0 +1,97 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>arbiter</artifactId>
|
||||
<groupId>net.brutex.ai</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>arbiter-core</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>arbiter-core</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>net.brutex.ai</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>com.google.code.findbugs</groupId>
|
||||
<artifactId>*</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
<version>${guava.jre.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>${commons.lang.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-math3</artifactId>
|
||||
<version>${commons.math.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>${slf4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>joda-time</groupId>
|
||||
<artifactId>joda-time</artifactId>
|
||||
<version>${jodatime.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-annotations</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>net.brutex.ai</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-joda</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>net.brutex.ai</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
<classifier>windows-x86_64</classifier>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
|
@ -0,0 +1,91 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<assembly>
|
||||
<id>bin</id>
|
||||
<!-- START SNIPPET: formats -->
|
||||
<formats>
|
||||
<format>tar.gz</format>
|
||||
<!--
|
||||
<format>tar.bz2</format>
|
||||
<format>zip</format>
|
||||
-->
|
||||
</formats>
|
||||
<!-- END SNIPPET: formats -->
|
||||
|
||||
<dependencySets>
|
||||
<dependencySet>
|
||||
<outputDirectory>lib</outputDirectory>
|
||||
<includes>
|
||||
<include>*:jar:*</include>
|
||||
</includes>
|
||||
<excludes>
|
||||
<exclude>*:sources</exclude>
|
||||
</excludes>
|
||||
</dependencySet>
|
||||
</dependencySets>
|
||||
|
||||
<!-- START SNIPPET: fileSets -->
|
||||
<fileSets>
|
||||
<fileSet>
|
||||
<includes>
|
||||
<include>readme.txt</include>
|
||||
</includes>
|
||||
</fileSet>
|
||||
|
||||
<fileSet>
|
||||
<directory>src/main/resources/bin/</directory>
|
||||
<outputDirectory>bin</outputDirectory>
|
||||
<includes>
|
||||
<include>arbiter</include>
|
||||
</includes>
|
||||
<lineEnding>unix</lineEnding>
|
||||
<fileMode>0755</fileMode>
|
||||
</fileSet>
|
||||
|
||||
<fileSet>
|
||||
<directory>examples</directory>
|
||||
<outputDirectory>examples</outputDirectory>
|
||||
<!--
|
||||
<lineEnding>unix</lineEnding>
|
||||
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
||||
-->
|
||||
</fileSet>
|
||||
|
||||
|
||||
<!--
|
||||
<fileSet>
|
||||
<directory>src/bin</directory>
|
||||
<outputDirectory>bin</outputDirectory>
|
||||
<includes>
|
||||
<include>hello</include>
|
||||
</includes>
|
||||
<lineEnding>unix</lineEnding>
|
||||
<fileMode>0755</fileMode>
|
||||
</fileSet>
|
||||
-->
|
||||
|
||||
<fileSet>
|
||||
<directory>target</directory>
|
||||
<outputDirectory>./</outputDirectory>
|
||||
<includes>
|
||||
<include>*.jar</include>
|
||||
</includes>
|
||||
</fileSet>
|
||||
|
||||
</fileSets>
|
||||
<!-- END SNIPPET: fileSets -->
|
||||
</assembly>
|
|
@ -0,0 +1,74 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Created by Alex on 23/07/2017.
|
||||
*/
|
||||
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
Map<String, ParameterSpace> m = new LinkedHashMap<>();
|
||||
|
||||
//Need to manually build and walk the class heirarchy...
|
||||
Class<?> currClass = this.getClass();
|
||||
List<Class<?>> classHeirarchy = new ArrayList<>();
|
||||
while (currClass != Object.class) {
|
||||
classHeirarchy.add(currClass);
|
||||
currClass = currClass.getSuperclass();
|
||||
}
|
||||
|
||||
for (int i = classHeirarchy.size() - 1; i >= 0; i--) {
|
||||
//Use reflection here to avoid a mass of boilerplate code...
|
||||
Field[] allFields = classHeirarchy.get(i).getDeclaredFields();
|
||||
|
||||
for (Field f : allFields) {
|
||||
|
||||
String name = f.getName();
|
||||
Class<?> fieldClass = f.getType();
|
||||
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
|
||||
|
||||
if (!isParamSpacefield) {
|
||||
continue;
|
||||
}
|
||||
|
||||
f.setAccessible(true);
|
||||
|
||||
ParameterSpace<?> p;
|
||||
try {
|
||||
p = (ParameterSpace<?>) f.get(this);
|
||||
} catch (IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (p != null) {
|
||||
m.put(name, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
|
||||
import org.nd4j.common.function.Supplier;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Candidate: a proposed hyperparameter configuration.
|
||||
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
|
||||
*/
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class Candidate<C> implements Serializable {
|
||||
|
||||
private Supplier<C> supplier;
|
||||
private int index;
|
||||
private double[] flatParameters;
|
||||
private Map<String, Object> dataParameters;
|
||||
private Exception exception;
|
||||
|
||||
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
|
||||
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
|
||||
}
|
||||
|
||||
public Candidate(C value, int index, double[] flatParameters) {
|
||||
this(new SerializedSupplier<C>(value), index, flatParameters);
|
||||
}
|
||||
|
||||
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
|
||||
this(value, index, flatParameters, null, null);
|
||||
}
|
||||
|
||||
public C getValue(){
|
||||
return supplier.get();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
/**
|
||||
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
|
||||
* This abstraction allows for different ways of generating the next configuration to test; for example,
|
||||
* random search, grid search, Bayesian optimization methods, etc.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface CandidateGenerator {
|
||||
|
||||
/**
|
||||
* Is this candidate generator able to generate more candidates? This will always return true in some
|
||||
* cases, but some search strategies have a limit (grid search, for example)
|
||||
*/
|
||||
boolean hasMoreCandidates();
|
||||
|
||||
/**
|
||||
* Generate a candidate hyperparameter configuration
|
||||
*/
|
||||
Candidate getCandidate();
|
||||
|
||||
/**
|
||||
* Report results for the candidate generator.
|
||||
*
|
||||
* @param result The results to report
|
||||
*/
|
||||
void reportResults(OptimizationResult result);
|
||||
|
||||
/**
|
||||
* @return Get the parameter space for this candidate generator
|
||||
*/
|
||||
ParameterSpace<?> getParameterSpace();
|
||||
|
||||
/**
|
||||
* @param rngSeed Set the random number generator seed for the candidate generator
|
||||
*/
|
||||
void setRngSeed(long rngSeed);
|
||||
|
||||
/**
|
||||
* @return The type (class) of the generated candidates
|
||||
*/
|
||||
Class<?> getCandidateType();
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* An optimization result represents the results of an optimization run, including the canditate configuration, the
|
||||
* trained model, the score for that model, and index of the model
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@JsonIgnoreProperties({"resultReference"})
|
||||
public class OptimizationResult implements Serializable {
|
||||
@JsonProperty
|
||||
private Candidate candidate;
|
||||
@JsonProperty
|
||||
private Double score;
|
||||
@JsonProperty
|
||||
private int index;
|
||||
@JsonProperty
|
||||
private Object modelSpecificResults;
|
||||
@JsonProperty
|
||||
private CandidateInfo candidateInfo;
|
||||
private ResultReference resultReference;
|
||||
|
||||
|
||||
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
|
||||
CandidateInfo candidateInfo, ResultReference resultReference) {
|
||||
this.candidate = candidate;
|
||||
this.score = score;
|
||||
this.index = index;
|
||||
this.modelSpecificResults = modelSpecificResults;
|
||||
this.candidateInfo = candidateInfo;
|
||||
this.resultReference = resultReference;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
|
||||
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
|
||||
* multiple nested ParameterSpaces
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ParameterSpace<P> {
|
||||
|
||||
/**
|
||||
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
|
||||
* mapping function (such as the prior probability distribution)
|
||||
*
|
||||
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
|
||||
*/
|
||||
P getValue(double[] parameterValues);
|
||||
|
||||
/**
|
||||
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
|
||||
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
|
||||
*
|
||||
* @return Number of hyperparameters to be optimized
|
||||
*/
|
||||
int numParameters();
|
||||
|
||||
/**
|
||||
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
|
||||
* nested parameter spaces
|
||||
*/
|
||||
List<ParameterSpace> collectLeaves();
|
||||
|
||||
/**
|
||||
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
|
||||
* nested parameter spaces. The map should be empty for leaf parameter spaces
|
||||
*
|
||||
* @return A map of nested parameter spaces
|
||||
*/
|
||||
Map<String, ParameterSpace> getNestedSpaces();
|
||||
|
||||
/**
|
||||
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
|
||||
*/
|
||||
@JsonIgnore
|
||||
boolean isLeaf();
|
||||
|
||||
/**
|
||||
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
|
||||
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
|
||||
*
|
||||
* @param indices Indices to set. Length should equal {@link #numParameters()}
|
||||
*/
|
||||
void setIndices(int... indices);
|
||||
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/**
|
||||
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
|
||||
* that can be executed as a Callable
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface TaskCreator {
|
||||
|
||||
/**
|
||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||
*
|
||||
* @param candidate Candidate (model) configuration to be trained
|
||||
* @param dataProvider DataProvider, for the data
|
||||
* @param scoreFunction Score function to be used to evaluate the model
|
||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||
*/
|
||||
@Deprecated
|
||||
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
|
||||
List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||
|
||||
/**
|
||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||
*
|
||||
* @param candidate Candidate (model) configuration to be trained
|
||||
* @param dataSource Data source
|
||||
* @param dataSourceProperties Properties (may be null) for the data source
|
||||
* @param scoreFunction Score function to be used to evaluate the model
|
||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||
*/
|
||||
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
|
||||
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class TaskCreatorProvider {
|
||||
|
||||
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
|
||||
|
||||
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
|
||||
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
|
||||
try {
|
||||
if(c == null){
|
||||
return null;
|
||||
}
|
||||
return c.newInstance();
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
|
||||
Class<? extends TaskCreator> creatorClass){
|
||||
map.put(spaceClass, creatorClass);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.adapter;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
|
||||
*
|
||||
* @param <F> Type to convert from
|
||||
* @param <T> Type to convert to
|
||||
* @author Alex Black
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
||||
|
||||
|
||||
protected abstract T convertValue(F from);
|
||||
|
||||
protected abstract ParameterSpace<F> underlying();
|
||||
|
||||
protected abstract String underlyingName();
|
||||
|
||||
|
||||
@Override
|
||||
public T getValue(double[] parameterValues) {
|
||||
return convertValue(underlying().getValue(parameterValues));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return underlying().numParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
ParameterSpace p = underlying();
|
||||
if(p.isLeaf()){
|
||||
return Collections.singletonList(p);
|
||||
}
|
||||
return underlying().collectLeaves();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false; //Underlying may be a leaf, however
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
underlying().setIndices(indices);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return underlying().toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* DataProvider interface abstracts out the providing of data
|
||||
* @deprecated Use {@link DataSource}
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@Deprecated
|
||||
public interface DataProvider extends Serializable {
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data.
|
||||
* Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
Object trainData(Map<String, Object> dataParameters);
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
Object testData(Map<String, Object> dataParameters);
|
||||
|
||||
Class<?> getDataType();
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* This is a {@link DataProvider} for
|
||||
* an {@link DataSetIteratorFactory} which
|
||||
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
|
||||
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
|
||||
* for use with arbiter.
|
||||
*
|
||||
* This {@link DataProvider} is mainly meant for use for command line driven
|
||||
* applications.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Data
|
||||
public class DataSetIteratorFactoryProvider implements DataProvider {
|
||||
|
||||
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data.
|
||||
* Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
@Override
|
||||
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
|
||||
return create(dataParameters);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data. Data parameters map
|
||||
* is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
@Override
|
||||
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
|
||||
return create(dataParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getDataType() {
|
||||
return DataSetIteratorFactory.class;
|
||||
}
|
||||
|
||||
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
|
||||
if (dataParameters == null)
|
||||
throw new IllegalArgumentException(
|
||||
"Data parameters is null. Please specify a class name to create a dataset iterator.");
|
||||
if (!dataParameters.containsKey(FACTORY_KEY))
|
||||
throw new IllegalArgumentException(
|
||||
"No data set iterator factory class found. Please specify a class name with key "
|
||||
+ FACTORY_KEY);
|
||||
String value = dataParameters.get(FACTORY_KEY).toString();
|
||||
try {
|
||||
Class<? extends DataSetIteratorFactory> clazz =
|
||||
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
||||
return clazz.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
 * DataSource: defines where the data should come from for training and testing.
 * Note that implementations must have a no-argument constructor
 *
 * @author Alex Black
 */
public interface DataSource extends Serializable {

    /**
     * Configure the current data source with the specified properties
     * Note: These properties are fixed for the training instance, and are optionally provided by the user
     * at the configuration stage.
     * The properties could be anything - and are usually specific to each DataSource implementation.
     * For example, values such as batch size could be set using these properties
     * @param properties Properties to apply to the data source instance
     */
    void configure(Properties properties);

    /**
     * Get training data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
     */
    Object trainData();

    /**
     * Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
     */
    Object testData();

    /**
     * The type of data returned by {@link #trainData()} and {@link #testData()}.
     * Usually DataSetIterator or MultiDataSetIterator
     * @return Class of the objects returned by trainData and testData
     */
    Class<?> getDataType();

}
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.evaluation;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* ModelEvaluator: Used to conduct additional evaluation.
|
||||
* For example, this may be classification performance on a test set or similar
|
||||
*/
|
||||
public interface ModelEvaluator extends Serializable {
|
||||
Object evaluateModel(Object model, DataProvider dataProvider);
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
/**
|
||||
* @return The datatypes supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedDataTypes();
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple class to store optimization results in-memory.
|
||||
* Not recommended for large (or a large number of) models.
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class InMemoryResultSaver implements ResultSaver {
|
||||
@Override
|
||||
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
|
||||
return new InMemoryResult(result, modelResult);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedCandidateTypes() {
|
||||
return Collections.<Class<?>>singletonList(Object.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedModelTypes() {
|
||||
return Collections.<Class<?>>singletonList(Object.class);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
private static class InMemoryResult implements ResultReference {
|
||||
private OptimizationResult result;
|
||||
private Object modelResult;
|
||||
|
||||
@Override
|
||||
public OptimizationResult getResult() throws IOException {
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getResultModel() throws IOException {
|
||||
return modelResult;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
|
||||
* parameters each)
|
||||
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
|
||||
* and we can easily load it back into memory (if/when required) using the getResult() method
|
||||
*/
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ResultReference {
|
||||
|
||||
OptimizationResult getResult() throws IOException;
|
||||
|
||||
Object getResultModel() throws IOException;
|
||||
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
|
||||
* regardless of where/how they are saved.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ResultSaver {
|
||||
|
||||
/**
|
||||
* Save the model (including configuration and any additional evaluation/results)
|
||||
*
|
||||
* @param result Optimization result for the model to save
|
||||
* @param modelResult Model result to save
|
||||
* @return ResultReference, such that the result can be loaded back into memory
|
||||
* @throws IOException If IO error occurs during model saving
|
||||
*/
|
||||
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
|
||||
|
||||
/**
|
||||
* @return The candidate types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedCandidateTypes();
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.score;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* ScoreFunction defines the objective of hyperparameter optimization.
|
||||
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
|
||||
* in the configuration.
|
||||
*
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ScoreFunction extends Serializable {
|
||||
|
||||
/**
|
||||
* Calculate and return the score, for the given model and data provider
|
||||
*
|
||||
* @param model Model to score
|
||||
* @param dataProvider Data provider - data to use
|
||||
* @param dataParameters Parameters for data
|
||||
* @return Calculated score
|
||||
*/
|
||||
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
|
||||
|
||||
/**
|
||||
* Calculate and return the score, for the given model and data provider
|
||||
*
|
||||
* @param model Model to score
|
||||
* @param dataSource Data source
|
||||
* @param dataSourceProperties data source properties
|
||||
* @return Calculated score
|
||||
*/
|
||||
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
|
||||
|
||||
/**
|
||||
* Should this score function be minimized or maximized?
|
||||
*
|
||||
* @return true if score should be minimized, false if score should be maximized
|
||||
*/
|
||||
boolean minimize();
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
/**
|
||||
* @return The data types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedDataTypes();
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
/**
|
||||
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
|
||||
* Note that this is counted as number of completed candidates, plus number of failed candidates.
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class MaxCandidatesCondition implements TerminationCondition {
|
||||
@JsonProperty
|
||||
private int maxCandidates;
|
||||
|
||||
@Override
|
||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MaxCandidatesCondition(" + maxCandidates + ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.joda.time.format.DateTimeFormat;
|
||||
import org.joda.time.format.DateTimeFormatter;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Terminate hyperparameter optimization after
|
||||
* a fixed amount of time has passed
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class MaxTimeCondition implements TerminationCondition {
|
||||
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
|
||||
|
||||
private long duration;
|
||||
private TimeUnit timeUnit;
|
||||
private long startTime;
|
||||
private long endTime;
|
||||
|
||||
|
||||
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
|
||||
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
|
||||
this.duration = duration;
|
||||
this.timeUnit = timeUnit;
|
||||
this.startTime = startTime;
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param duration Duration of time
|
||||
* @param timeUnit Unit that the duration is specified in
|
||||
*/
|
||||
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
|
||||
this.duration = duration;
|
||||
this.timeUnit = timeUnit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||
startTime = System.currentTimeMillis();
|
||||
this.endTime = startTime + timeUnit.toMillis(duration);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||
return System.currentTimeMillis() >= endTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (startTime > 0) {
|
||||
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
|
||||
+ "\",end=\"" + formatter.print(endTime) + "\")";
|
||||
} else {
|
||||
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
/**
|
||||
* Global termination condition for conducting hyperparameter optimization.
|
||||
* Termination conditions are used to determine if/when the optimization should stop.
|
||||
*/
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public interface TerminationCondition {
|
||||
|
||||
/**
|
||||
* Initialize the termination condition (such as starting timers, etc).
|
||||
*/
|
||||
void initialize(IOptimizationRunner optimizationRunner);
|
||||
|
||||
/**
|
||||
* Determine whether optimization should be terminated
|
||||
*
|
||||
* @param optimizationRunner Optimization runner
|
||||
* @return true if learning should be terminated, false otherwise
|
||||
*/
|
||||
boolean terminate(IOptimizationRunner optimizationRunner);
|
||||
|
||||
}
|
|
@ -0,0 +1,226 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.config;
|
||||
|
||||
import lombok.*;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* OptimizationConfiguration ties together all of the various
|
||||
* components (such as data, score functions, result saving etc)
|
||||
* required to execute hyperparameter optimization.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public class OptimizationConfiguration {
|
||||
@JsonSerialize
|
||||
private DataProvider dataProvider;
|
||||
@JsonSerialize
|
||||
private Class<? extends DataSource> dataSource;
|
||||
@JsonSerialize
|
||||
private Properties dataSourceProperties;
|
||||
@JsonSerialize
|
||||
private CandidateGenerator candidateGenerator;
|
||||
@JsonSerialize
|
||||
private ResultSaver resultSaver;
|
||||
@JsonSerialize
|
||||
private ScoreFunction scoreFunction;
|
||||
@JsonSerialize
|
||||
private List<TerminationCondition> terminationConditions;
|
||||
@JsonSerialize
|
||||
private Long rngSeed;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private long executionStartTime;
|
||||
|
||||
|
||||
private OptimizationConfiguration(Builder builder) {
|
||||
this.dataProvider = builder.dataProvider;
|
||||
this.dataSource = builder.dataSource;
|
||||
this.dataSourceProperties = builder.dataSourceProperties;
|
||||
this.candidateGenerator = builder.candidateGenerator;
|
||||
this.resultSaver = builder.resultSaver;
|
||||
this.scoreFunction = builder.scoreFunction;
|
||||
this.terminationConditions = builder.terminationConditions;
|
||||
this.rngSeed = builder.rngSeed;
|
||||
|
||||
if (rngSeed != null)
|
||||
candidateGenerator.setRngSeed(rngSeed);
|
||||
|
||||
//Validate the configuration: data types, score types, etc
|
||||
//TODO
|
||||
|
||||
//Validate that the dataSource has a no-arg constructor
|
||||
if (dataSource != null) {
|
||||
try {
|
||||
dataSource.getConstructor();
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private DataProvider dataProvider;
|
||||
private Class<? extends DataSource> dataSource;
|
||||
private Properties dataSourceProperties;
|
||||
private CandidateGenerator candidateGenerator;
|
||||
private ResultSaver resultSaver;
|
||||
private ScoreFunction scoreFunction;
|
||||
private List<TerminationCondition> terminationConditions;
|
||||
private Long rngSeed;
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #dataSource(Class, Properties)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder dataProvider(DataProvider dataProvider) {
|
||||
this.dataProvider = dataProvider;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* DataSource: defines where the data should come from for training and testing.
|
||||
* Note that implementations must have a no-argument contsructor
|
||||
*
|
||||
* @param dataSource Class for the data source
|
||||
* @param dataSourceProperties May be null. Properties for configuring the data source
|
||||
*/
|
||||
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
|
||||
this.dataSource = dataSource;
|
||||
this.dataSourceProperties = dataSourceProperties;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
|
||||
this.candidateGenerator = candidateGenerator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder modelSaver(ResultSaver resultSaver) {
|
||||
this.resultSaver = resultSaver;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder scoreFunction(ScoreFunction scoreFunction) {
|
||||
this.scoreFunction = scoreFunction;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Termination conditions to use
|
||||
*
|
||||
* @param conditions
|
||||
* @return
|
||||
*/
|
||||
public Builder terminationConditions(TerminationCondition... conditions) {
|
||||
terminationConditions = Arrays.asList(conditions);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
|
||||
this.terminationConditions = terminationConditions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder rngSeed(long rngSeed) {
|
||||
this.rngSeed = rngSeed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public OptimizationConfiguration build() {
|
||||
return new OptimizationConfiguration(this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create an optimization configuration from the json
|
||||
*
|
||||
* @param json the json to create the config from
|
||||
* For type definitions
|
||||
* @see OptimizationConfiguration
|
||||
*/
|
||||
public static OptimizationConfiguration fromYaml(String json) {
|
||||
try {
|
||||
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an optimization configuration from the json
|
||||
*
|
||||
* @param json the json to create the config from
|
||||
* @see OptimizationConfiguration
|
||||
*/
|
||||
public static OptimizationConfiguration fromJson(String json) {
|
||||
try {
|
||||
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a json configuration of this optimization configuration
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public String toJson() {
|
||||
try {
|
||||
return JsonMapper.getMapper().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a yaml configuration of this optimization configuration
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public String toYaml() {
|
||||
try {
|
||||
return JsonMapper.getYamlMapper().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
/**
|
||||
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
|
||||
*/
|
||||
public class DegenerateIntegerDistribution implements IntegerDistribution {
|
||||
private int value;
|
||||
|
||||
public DegenerateIntegerDistribution(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public double probability(int x) {
|
||||
return (x == value ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(int x) {
|
||||
return (x >= value ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
|
||||
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalMean() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalVariance() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSupportLowerBound() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSupportUpperBound() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
//no op
|
||||
}
|
||||
|
||||
@Override
|
||||
public int sample() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] sample(int sampleSize) {
|
||||
int[] out = new int[sampleSize];
|
||||
for (int i = 0; i < out.length; i++)
|
||||
out[i] = value;
|
||||
return out;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
|
||||
/**
|
||||
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
|
||||
* don't implement serializable etc.
|
||||
* Which makes unit testing etc quite difficult.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class DistributionUtils {
|
||||
|
||||
// Private no-arg constructor: this class cannot be instantiated from outside
private DistributionUtils() {}
|
||||
|
||||
|
||||
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
|
||||
if (a.getClass() != b.getClass())
|
||||
return false;
|
||||
Class<?> c = a.getClass();
|
||||
if (c == BetaDistribution.class) {
|
||||
BetaDistribution ba = (BetaDistribution) a;
|
||||
BetaDistribution bb = (BetaDistribution) b;
|
||||
|
||||
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
|
||||
} else if (c == CauchyDistribution.class) {
|
||||
CauchyDistribution ca = (CauchyDistribution) a;
|
||||
CauchyDistribution cb = (CauchyDistribution) b;
|
||||
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
|
||||
} else if (c == ChiSquaredDistribution.class) {
|
||||
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
|
||||
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
|
||||
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
|
||||
} else if (c == ExponentialDistribution.class) {
|
||||
ExponentialDistribution ea = (ExponentialDistribution) a;
|
||||
ExponentialDistribution eb = (ExponentialDistribution) b;
|
||||
return ea.getMean() == eb.getMean();
|
||||
} else if (c == FDistribution.class) {
|
||||
FDistribution fa = (FDistribution) a;
|
||||
FDistribution fb = (FDistribution) b;
|
||||
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
|
||||
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
|
||||
} else if (c == GammaDistribution.class) {
|
||||
GammaDistribution ga = (GammaDistribution) a;
|
||||
GammaDistribution gb = (GammaDistribution) b;
|
||||
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
|
||||
} else if (c == LevyDistribution.class) {
|
||||
LevyDistribution la = (LevyDistribution) a;
|
||||
LevyDistribution lb = (LevyDistribution) b;
|
||||
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
|
||||
} else if (c == LogNormalDistribution.class) {
|
||||
LogNormalDistribution la = (LogNormalDistribution) a;
|
||||
LogNormalDistribution lb = (LogNormalDistribution) b;
|
||||
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
|
||||
} else if (c == NormalDistribution.class) {
|
||||
NormalDistribution na = (NormalDistribution) a;
|
||||
NormalDistribution nb = (NormalDistribution) b;
|
||||
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
|
||||
} else if (c == ParetoDistribution.class) {
|
||||
ParetoDistribution pa = (ParetoDistribution) a;
|
||||
ParetoDistribution pb = (ParetoDistribution) b;
|
||||
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
|
||||
} else if (c == TDistribution.class) {
|
||||
TDistribution ta = (TDistribution) a;
|
||||
TDistribution tb = (TDistribution) b;
|
||||
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
|
||||
} else if (c == TriangularDistribution.class) {
|
||||
TriangularDistribution ta = (TriangularDistribution) a;
|
||||
TriangularDistribution tb = (TriangularDistribution) b;
|
||||
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
|
||||
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
|
||||
} else if (c == UniformRealDistribution.class) {
|
||||
UniformRealDistribution ua = (UniformRealDistribution) a;
|
||||
UniformRealDistribution ub = (UniformRealDistribution) b;
|
||||
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
|
||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||
} else if (c == WeibullDistribution.class) {
|
||||
WeibullDistribution wa = (WeibullDistribution) a;
|
||||
WeibullDistribution wb = (WeibullDistribution) b;
|
||||
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
|
||||
} else if (c == LogUniformDistribution.class ){
|
||||
LogUniformDistribution lu_a = (LogUniformDistribution)a;
|
||||
LogUniformDistribution lu_b = (LogUniformDistribution)b;
|
||||
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
|
||||
if (a.getClass() != b.getClass())
|
||||
return false;
|
||||
Class<?> c = a.getClass();
|
||||
|
||||
if (c == BinomialDistribution.class) {
|
||||
BinomialDistribution ba = (BinomialDistribution) a;
|
||||
BinomialDistribution bb = (BinomialDistribution) b;
|
||||
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
|
||||
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
|
||||
} else if (c == GeometricDistribution.class) {
|
||||
GeometricDistribution ga = (GeometricDistribution) a;
|
||||
GeometricDistribution gb = (GeometricDistribution) b;
|
||||
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
|
||||
} else if (c == HypergeometricDistribution.class) {
|
||||
HypergeometricDistribution ha = (HypergeometricDistribution) a;
|
||||
HypergeometricDistribution hb = (HypergeometricDistribution) b;
|
||||
return ha.getPopulationSize() == hb.getPopulationSize()
|
||||
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
|
||||
&& ha.getSampleSize() == hb.getSampleSize();
|
||||
} else if (c == PascalDistribution.class) {
|
||||
PascalDistribution pa = (PascalDistribution) a;
|
||||
PascalDistribution pb = (PascalDistribution) b;
|
||||
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
|
||||
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
|
||||
} else if (c == PoissonDistribution.class) {
|
||||
PoissonDistribution pa = (PoissonDistribution) a;
|
||||
PoissonDistribution pb = (PoissonDistribution) b;
|
||||
return pa.getMean() == pb.getMean();
|
||||
} else if (c == UniformIntegerDistribution.class) {
|
||||
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
|
||||
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
|
||||
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
|
||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||
} else if (c == ZipfDistribution.class) {
|
||||
ZipfDistribution za = (ZipfDistribution) a;
|
||||
ZipfDistribution zb = (ZipfDistribution) b;
|
||||
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Log uniform distribution, with support in range [min, max] for min > 0
|
||||
*
|
||||
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class LogUniformDistribution implements RealDistribution {
|
||||
|
||||
@Getter private final double min;
|
||||
@Getter private final double max;
|
||||
|
||||
private final double logMin;
|
||||
private final double logMax;
|
||||
|
||||
private transient Random rng = new Random();
|
||||
|
||||
/**
|
||||
*
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value
|
||||
*/
|
||||
public LogUniformDistribution(double min, double max) {
|
||||
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
|
||||
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
|
||||
+ min + "," + max + ")");
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
|
||||
this.logMin = Math.log(min);
|
||||
this.logMax = Math.log(max);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double probability(double x) {
|
||||
if(x < min || x > max){
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1.0 / (x * (logMax - logMin));
|
||||
}
|
||||
|
||||
@Override
|
||||
public double density(double x) {
|
||||
return probability(x);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(double x) {
|
||||
if(x <= min){
|
||||
return 0.0;
|
||||
} else if(x >= max){
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
return (Math.log(x)-logMin)/(logMax-logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
|
||||
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
|
||||
return Math.exp(p * (logMax-logMin) + logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalMean() {
|
||||
return (max-min)/(logMax-logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalVariance() {
|
||||
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
|
||||
return d1 / (2*Math.pow(logMax-logMin, 2.0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getSupportLowerBound() {
|
||||
return min;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getSupportUpperBound() {
|
||||
return max;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportLowerBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportUpperBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
rng.setSeed(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double sample() {
|
||||
return inverseCumulativeProbability(rng.nextDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] sample(int sampleSize) {
|
||||
double[] d = new double[sampleSize];
|
||||
for( int i=0; i<sampleSize; i++ ){
|
||||
d[i] = sample();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
|
||||
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
|
||||
* are built.
|
||||
*
|
||||
* @param <T> Type of candidates to generate
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
|
||||
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
|
||||
protected ParameterSpace<T> parameterSpace;
|
||||
protected AtomicInteger candidateCounter = new AtomicInteger(0);
|
||||
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
protected Map<String, Object> dataParameters;
|
||||
protected boolean initDone = false;
|
||||
|
||||
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
|
||||
boolean initDone) {
|
||||
this.parameterSpace = parameterSpace;
|
||||
this.dataParameters = dataParameters;
|
||||
this.initDone = initDone;
|
||||
}
|
||||
|
||||
protected void initialize() {
|
||||
if(!initDone) {
|
||||
//First: collect leaf parameter spaces objects and remove duplicates
|
||||
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||
|
||||
//Second: assign each a number
|
||||
int i = 0;
|
||||
for (ParameterSpace ps : noDuplicatesList) {
|
||||
int np = ps.numParameters();
|
||||
if (np == 1) {
|
||||
ps.setIndices(i++);
|
||||
} else {
|
||||
int[] values = new int[np];
|
||||
for (int j = 0; j < np; j++)
|
||||
values[j] = i++;
|
||||
ps.setIndices(values);
|
||||
}
|
||||
}
|
||||
initDone = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParameterSpace<T> getParameterSpace() {
|
||||
return parameterSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reportResults(OptimizationResult result) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngSeed(long rngSeed) {
|
||||
rng.setSeed(rngSeed);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Uses a genetic algorithm to generate candidates.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Slf4j
|
||||
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||
|
||||
@Getter
|
||||
protected final PopulationModel populationModel;
|
||||
|
||||
protected final ChromosomeFactory chromosomeFactory;
|
||||
protected final SelectionOperator selectionOperator;
|
||||
|
||||
protected boolean hasMoreCandidates = true;
|
||||
|
||||
public static class Builder {
|
||||
protected final ParameterSpace<?> parameterSpace;
|
||||
|
||||
protected Map<String, Object> dataParameters;
|
||||
protected boolean initDone;
|
||||
protected boolean minimizeScore;
|
||||
protected PopulationModel populationModel;
|
||||
protected ChromosomeFactory chromosomeFactory;
|
||||
protected SelectionOperator selectionOperator;
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
|
||||
*/
|
||||
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
|
||||
this.parameterSpace = parameterSpace;
|
||||
this.minimizeScore = scoreFunction.minimize();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param populationModel The PopulationModel instance to use.
|
||||
*/
|
||||
public Builder populationModel(PopulationModel populationModel) {
|
||||
this.populationModel = populationModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
|
||||
*/
|
||||
public Builder selectionOperator(SelectionOperator selectionOperator) {
|
||||
this.selectionOperator = selectionOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder dataParameters(Map<String, Object> dataParameters) {
|
||||
|
||||
this.dataParameters = dataParameters;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
|
||||
this.initDone = initDone;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param chromosomeFactory The ChromosomeFactory to use
|
||||
*/
|
||||
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
|
||||
this.chromosomeFactory = chromosomeFactory;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSearchCandidateGenerator build() {
|
||||
if (populationModel == null) {
|
||||
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
|
||||
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (chromosomeFactory == null) {
|
||||
chromosomeFactory = new ChromosomeFactory();
|
||||
}
|
||||
|
||||
if (selectionOperator == null) {
|
||||
selectionOperator = new GeneticSelectionOperator.Builder().build();
|
||||
}
|
||||
|
||||
return new GeneticSearchCandidateGenerator(this);
|
||||
}
|
||||
}
|
||||
|
||||
private GeneticSearchCandidateGenerator(Builder builder) {
|
||||
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
|
||||
|
||||
initialize();
|
||||
|
||||
chromosomeFactory = builder.chromosomeFactory;
|
||||
populationModel = builder.populationModel;
|
||||
selectionOperator = builder.selectionOperator;
|
||||
|
||||
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
|
||||
populationModel.initializeInstance(builder.minimizeScore);
|
||||
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return hasMoreCandidates;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
|
||||
double[] values = null;
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
|
||||
try {
|
||||
values = selectionOperator.buildNextGenes();
|
||||
value = parameterSpace.getValue(values);
|
||||
} catch (GeneticGenerationException e2) {
|
||||
log.warn("Error generating candidate", e2);
|
||||
e = e2;
|
||||
hasMoreCandidates = false;
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "GeneticSearchCandidateGenerator";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reportResults(OptimizationResult result) {
|
||||
if (result.getScore() == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
|
||||
result.getScore());
|
||||
populationModel.add(newChromosome);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,232 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.math3.random.RandomAdaptor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
|
||||
|
||||
/**
|
||||
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
|
||||
* Note that:<br>
|
||||
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
|
||||
* that hyperparameter<br>
|
||||
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
|
||||
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
|
||||
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
|
||||
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
|
||||
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
|
||||
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
|
||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||
|
||||
/**
|
||||
* In what order should candidates be generated?<br>
|
||||
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
|
||||
* will be changed least rapidly.<br>
|
||||
* <b>RandomOrder</b>: generate candidates in a random order<br>
|
||||
* In both cases, the same candidates will be generated; only the order of generation is different
|
||||
*/
|
||||
public enum Mode {
|
||||
Sequential, RandomOrder
|
||||
}
|
||||
|
||||
private final int discretizationCount;
|
||||
private final Mode mode;
|
||||
|
||||
private int[] numValuesPerParam;
|
||||
@Getter
|
||||
private int totalNumCandidates;
|
||||
private Queue<Integer> order;
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||
* in which candidates should be generated.
|
||||
*/
|
||||
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
|
||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||
@JsonProperty("initDone") boolean initDone) {
|
||||
super(parameterSpace, dataParameters, initDone);
|
||||
this.discretizationCount = discretizationCount;
|
||||
this.mode = mode;
|
||||
initialize();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||
* in which candidates should be generated.
|
||||
*/
|
||||
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
|
||||
Map<String, Object> dataParameters){
|
||||
this(parameterSpace, discretizationCount, mode, dataParameters, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void initialize() {
|
||||
super.initialize();
|
||||
|
||||
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||
int nParams = leaves.size();
|
||||
|
||||
//Work out for each parameter: is it continuous or discrete?
|
||||
// for grid search: discrete values are grid-searchable as-is
|
||||
// continuous values: discretize using 'discretizationCount' bins
|
||||
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
|
||||
numValuesPerParam = new int[nParams];
|
||||
long searchSize = 1;
|
||||
for (int i = 0; i < nParams; i++) {
|
||||
ParameterSpace ps = leaves.get(i);
|
||||
if (ps instanceof DiscreteParameterSpace) {
|
||||
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
|
||||
numValuesPerParam[i] = dps.numValues();
|
||||
} else if (ps instanceof IntegerParameterSpace) {
|
||||
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
|
||||
int min = ips.getMin();
|
||||
int max = ips.getMax();
|
||||
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
||||
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
||||
} else if (ps instanceof FixedValue){
|
||||
numValuesPerParam[i] = 1;
|
||||
} else {
|
||||
numValuesPerParam[i] = discretizationCount;
|
||||
}
|
||||
searchSize *= numValuesPerParam[i];
|
||||
}
|
||||
|
||||
if (searchSize >= Integer.MAX_VALUE)
|
||||
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
|
||||
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
|
||||
|
||||
order = new ConcurrentLinkedQueue<>();
|
||||
|
||||
totalNumCandidates = (int) searchSize;
|
||||
switch (mode) {
|
||||
case Sequential:
|
||||
for (int i = 0; i < totalNumCandidates; i++) {
|
||||
order.add(i);
|
||||
}
|
||||
break;
|
||||
case RandomOrder:
|
||||
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
|
||||
for (int i = 0; i < totalNumCandidates; i++) {
|
||||
tempList.add(i);
|
||||
}
|
||||
|
||||
Collections.shuffle(tempList, new RandomAdaptor(rng));
|
||||
order.addAll(tempList);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return !order.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
int next = order.remove();
|
||||
|
||||
//Next: max integer (candidate number) to values
|
||||
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
|
||||
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
try {
|
||||
value = parameterSpace.getValue(values);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
|
||||
//How? first map to index of num possible values. Then: to double values in range 0 to 1
|
||||
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
||||
//Based on: Nd4j Shape.ind2sub
|
||||
|
||||
int countNon1 = 0;
|
||||
for( int i : numValuesPerParam)
|
||||
if(i > 1)
|
||||
countNon1++;
|
||||
|
||||
int denom = product;
|
||||
int num = candidateIdx;
|
||||
int[] index = new int[numValuesPerParam.length];
|
||||
|
||||
for (int i = index.length - 1; i >= 0; i--) {
|
||||
denom /= numValuesPerParam[i];
|
||||
index[i] = num / denom;
|
||||
num %= denom;
|
||||
}
|
||||
|
||||
//Now: convert indexes to values in range [0,1]
|
||||
//min value -> 0
|
||||
//max value -> 1
|
||||
double[] out = new double[countNon1];
|
||||
int outIdx = 0;
|
||||
for (int i = 0; i < numValuesPerParam.length; i++) {
|
||||
if (numValuesPerParam[i] > 1){
|
||||
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "GridSearchCandidateGenerator(mode=" + mode + ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* RandomSearchGenerator: generates candidates at random.<br>
|
||||
* Note: if a probability distribution is provided for continuous hyperparameters,
|
||||
* this will be taken into account
|
||||
* when generating candidates. This allows the search to be weighted more towards
|
||||
* certain values according to a probability
|
||||
* density. For example: generate samples for learning rate according to log uniform distribution
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||
public class RandomSearchGenerator extends BaseCandidateGenerator {
|
||||
|
||||
@JsonCreator
|
||||
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||
@JsonProperty("initDone") boolean initDone) {
|
||||
super(parameterSpace, dataParameters, initDone);
|
||||
initialize();
|
||||
}
|
||||
|
||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
|
||||
this(parameterSpace, dataParameters, false);
|
||||
}
|
||||
|
||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
|
||||
this(parameterSpace, null, false);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
double[] randomValues = new double[parameterSpace.numParameters()];
|
||||
for (int i = 0; i < randomValues.length; i++)
|
||||
randomValues[i] = rng.nextDouble();
|
||||
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
try {
|
||||
value = parameterSpace.getValue(randomValues);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "RandomSearchGenerator";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Candidates are stored as Chromosome in the population model
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Data
|
||||
public class Chromosome {
|
||||
/**
|
||||
* The fitness score of the genes.
|
||||
*/
|
||||
protected final double fitness;
|
||||
|
||||
/**
|
||||
* The genes.
|
||||
*/
|
||||
protected final double[] genes;
|
||||
|
||||
public Chromosome(double[] genes, double fitness) {
|
||||
this.genes = genes;
|
||||
this.fitness = fitness;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||
|
||||
/**
|
||||
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class ChromosomeFactory {
|
||||
private int chromosomeLength;
|
||||
|
||||
/**
|
||||
* Called by the GeneticSearchCandidateGenerator.
|
||||
*/
|
||||
public void initializeInstance(int chromosomeLength) {
|
||||
this.chromosomeLength = chromosomeLength;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new instance of a Chromosome
|
||||
*
|
||||
* @param genes The genes
|
||||
* @param fitness The fitness score
|
||||
* @return A new instance of Chromosome
|
||||
*/
|
||||
public Chromosome createChromosome(double[] genes, double fitness) {
|
||||
return new Chromosome(genes, fitness);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The number of genes in a chromosome
|
||||
*/
|
||||
public int getChromosomeLength() {
|
||||
return chromosomeLength;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* A crossover operator that linearly combines the genes of two parents. <br>
|
||||
* When a crossover is generated (with a probability of <i>crossover rate</i>), each gene is a linear combination of the corresponding genes of the parents.
|
||||
* <p>
|
||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ArithmeticCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new ArithmeticCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
|
||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
double[] offspringValues = new double[parents[0].length];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
for (int i = 0; i < offspringValues.length; ++i) {
|
||||
double t = rng.nextDouble();
|
||||
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
|
||||
}
|
||||
return new CrossoverResult(true, offspringValues);
|
||||
}
|
||||
|
||||
return new CrossoverResult(false, parents[0]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* Abstract class for all crossover operators
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class CrossoverOperator {
|
||||
protected PopulationModel populationModel;
|
||||
|
||||
/**
|
||||
* Will be called by the selection operator once the population model is instantiated.
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
this.populationModel = populationModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the crossover
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
public abstract CrossoverResult crossover();
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Returned by a crossover operator
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Data
|
||||
public class CrossoverResult {
|
||||
/**
|
||||
* If false, there was no crossover and the operator simply returned the genes of a random parent.
|
||||
* If true, the genes are the result of a crossover.
|
||||
*/
|
||||
private final boolean isModified;
|
||||
|
||||
/**
|
||||
* The genes returned by the operator.
|
||||
*/
|
||||
private final double[] genes;
|
||||
|
||||
public CrossoverResult(boolean isModified, double[] genes) {
|
||||
this.isModified = isModified;
|
||||
this.genes = genes;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
import java.util.Deque;
|
||||
|
||||
/**
|
||||
* The K-Point crossover will select at random multiple crossover points.<br>
|
||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
|
||||
*/
|
||||
public class KPointCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
private static final int DEFAULT_MIN_CROSSOVER = 1;
|
||||
private static final int DEFAULT_MAX_CROSSOVER = 4;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final int minCrossovers;
|
||||
private final int maxCrossovers;
|
||||
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
|
||||
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of crossovers points (default is min 1, max 4)
|
||||
*
|
||||
* @param min The minimum number
|
||||
* @param max The maximum number
|
||||
*/
|
||||
public Builder numCrossovers(int min, int max) {
|
||||
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
|
||||
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
|
||||
|
||||
this.minCrossovers = min;
|
||||
this.maxCrossovers = max;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a fixed number of crossover points
|
||||
*
|
||||
* @param num The number of crossovers
|
||||
*/
|
||||
public Builder numCrossovers(int num) {
|
||||
Preconditions.checkState(num >= 0, "Num must be positive");
|
||||
|
||||
this.minCrossovers = num;
|
||||
this.maxCrossovers = num;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public KPointCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new KPointCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private CrossoverPointsGenerator crossoverPointsGenerator;
|
||||
|
||||
private KPointCrossover(KPointCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.maxCrossovers = builder.maxCrossovers;
|
||||
this.minCrossovers = builder.minCrossovers;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
|
||||
if (crossoverPointsGenerator == null) {
|
||||
crossoverPointsGenerator =
|
||||
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
|
||||
}
|
||||
|
||||
return crossoverPointsGenerator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
|
||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
boolean isModified = false;
|
||||
double[] resultGenes = parents[0];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
// Select crossover points
|
||||
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
|
||||
|
||||
// Crossover
|
||||
resultGenes = new double[parents[0].length];
|
||||
int currentParent = 0;
|
||||
int nextCrossover = crossoverPoints.pop();
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
if (i == nextCrossover) {
|
||||
currentParent = currentParent == 0 ? 1 : 0;
|
||||
nextCrossover = crossoverPoints.pop();
|
||||
}
|
||||
resultGenes[i] = parents[currentParent][i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* The single point crossover will select a random point where every gene before that point comes from one parent
|
||||
* and after which every gene comes from the other parent.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
|
||||
private final RandomGenerator rng;
|
||||
private final double crossoverRate;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public SinglePointCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new SinglePointCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
|
||||
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
boolean isModified = false;
|
||||
double[] resultGenes = parents[0];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
int chromosomeLength = parents[0].length;
|
||||
|
||||
// Crossover
|
||||
resultGenes = new double[chromosomeLength];
|
||||
|
||||
int crossoverPoint = rng.nextInt(chromosomeLength);
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* Abstract class for all crossover operators that applies to two parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
|
||||
|
||||
protected final TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* @param parentSelection A parent selection that selects two parents.
|
||||
*/
|
||||
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Will be called by the selection operator once the population model is instantiated.
|
||||
*/
|
||||
@Override
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
super.initializeInstance(populationModel);
|
||||
parentSelection.initializeInstance(populationModel.getPopulation());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class UniformCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final double parentBiasFactor;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A factor that will introduce a bias in the parent selection.<br>
|
||||
*
|
||||
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
|
||||
*/
|
||||
public Builder parentBiasFactor(double factor) {
|
||||
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
|
||||
factor);
|
||||
|
||||
this.parentBiasFactor = factor;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public UniformCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
return new UniformCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private UniformCrossover(UniformCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.parentBiasFactor = builder.parentBiasFactor;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
|
||||
* One of the parent may be favored if the bias is different than 0.5
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
// select the parents
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
double[] resultGenes = parents[0];
|
||||
boolean isModified = false;
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
// Crossover
|
||||
resultGenes = new double[parents[0].length];
|
||||
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Abstract class for all parent selection behaviors
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class ParentSelection {
|
||||
protected List<Chromosome> population;
|
||||
|
||||
/**
|
||||
* Will be called by the crossover operator once the population model is instantiated.
|
||||
*/
|
||||
public void initializeInstance(List<Chromosome> population) {
|
||||
this.population = population;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the parent selection
|
||||
*
|
||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||
*/
|
||||
public abstract double[][] selectParents();
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
|
||||
/**
|
||||
* A parent selection behavior that returns two random parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class RandomTwoParentSelection extends TwoParentSelection {
|
||||
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public RandomTwoParentSelection() {
|
||||
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public RandomTwoParentSelection(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects two random parents
|
||||
*
|
||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||
*/
|
||||
@Override
|
||||
public double[][] selectParents() {
|
||||
double[][] parents = new double[2][];
|
||||
|
||||
int parent1Idx = rng.nextInt(population.size());
|
||||
int parent2Idx;
|
||||
do {
|
||||
parent2Idx = rng.nextInt(population.size());
|
||||
} while (parent1Idx == parent2Idx);
|
||||
|
||||
parents[0] = population.get(parent1Idx).getGenes();
|
||||
parents[1] = population.get(parent2Idx).getGenes();
|
||||
|
||||
return parents;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
/**
|
||||
* Abstract class for all parent selection behaviors that selects two parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class TwoParentSelection extends ParentSelection {
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A helper class used by {@link KPointCrossover} to generate the crossover points
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class CrossoverPointsGenerator {
|
||||
private final int minCrossovers;
|
||||
private final int maxCrossovers;
|
||||
private final RandomGenerator rng;
|
||||
private List<Integer> parameterIndexes;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param chromosomeLength The number of genes
|
||||
* @param minCrossovers The minimum number of crossover points to generate
|
||||
* @param maxCrossovers The maximum number of crossover points to generate
|
||||
* @param rng A RandomGenerator instance
|
||||
*/
|
||||
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
|
||||
this.minCrossovers = minCrossovers;
|
||||
this.maxCrossovers = maxCrossovers;
|
||||
this.rng = rng;
|
||||
parameterIndexes = new ArrayList<Integer>();
|
||||
for (int i = 0; i < chromosomeLength; ++i) {
|
||||
parameterIndexes.add(i);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a list of crossover points.
|
||||
*
|
||||
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
|
||||
*/
|
||||
public Deque<Integer> getCrossoverPoints() {
|
||||
Collections.shuffle(parameterIndexes);
|
||||
List<Integer> crossoverPointLists =
|
||||
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
|
||||
Collections.sort(crossoverPointLists);
|
||||
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
|
||||
crossoverPoints.add(Integer.MAX_VALUE);
|
||||
|
||||
return crossoverPoints;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* The cull operator will remove from the population the least desirables chromosomes.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface CullOperator {
|
||||
/**
|
||||
* Will be called by the population model once created.
|
||||
*/
|
||||
void initializeInstance(PopulationModel populationModel);
|
||||
|
||||
/**
|
||||
* Cull the population to the culled size.
|
||||
*/
|
||||
void cullPopulation();
|
||||
|
||||
/**
|
||||
* @return The target population size after culling.
|
||||
*/
|
||||
int getCulledSize();
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
/**
|
||||
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class LeastFitCullOperator extends RatioCullOperator {
|
||||
|
||||
/**
|
||||
* The default cull ratio is 1/3.
|
||||
*/
|
||||
public LeastFitCullOperator() {
|
||||
super();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||
*/
|
||||
public LeastFitCullOperator(double cullRatio) {
|
||||
super(cullRatio);
|
||||
}
|
||||
|
||||
/**
|
||||
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
|
||||
*/
|
||||
@Override
|
||||
public void cullPopulation() {
|
||||
while (population.size() > culledSize) {
|
||||
population.remove(population.size() - 1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class RatioCullOperator implements CullOperator {
|
||||
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
|
||||
protected int culledSize;
|
||||
protected List<Chromosome> population;
|
||||
protected final double cullRatio;
|
||||
|
||||
/**
|
||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||
*/
|
||||
public RatioCullOperator(double cullRatio) {
|
||||
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
|
||||
cullRatio);
|
||||
|
||||
this.cullRatio = cullRatio;
|
||||
}
|
||||
|
||||
/**
|
||||
* The default cull ratio is 1/3
|
||||
*/
|
||||
public RatioCullOperator() {
|
||||
this(DEFAULT_CULL_RATIO);
|
||||
}
|
||||
|
||||
/**
|
||||
* Will be called by the population model once created.
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
this.population = populationModel.getPopulation();
|
||||
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The target population size after culling.
|
||||
*/
|
||||
@Override
|
||||
public int getCulledSize() {
|
||||
return culledSize;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
|
||||
|
||||
/**
 * Unchecked exception type of the genetic candidate generation package.
 */
public class GeneticGenerationException extends RuntimeException {
    /**
     * @param message The detail message, passed unchanged to {@link RuntimeException}
     */
    public GeneticGenerationException(String message) {
        super(message);
    }
}
|
|
@ -0,0 +1,33 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||
|
||||
/**
 * A mutation operator applies a mutation to the genes it is given.
 *
 * @author Alexandre Boulanger
 */
public interface MutationOperator {

    /**
     * Applies a mutation to the supplied genes.
     *
     * @param genes The genes to be mutated
     * @return True when the genes were mutated, false when they were left untouched.
     */
    boolean mutate(double[] genes);
}
|
|
@ -0,0 +1,93 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class RandomMutationOperator implements MutationOperator {
|
||||
private static final double DEFAULT_MUTATION_RATE = 0.005;
|
||||
|
||||
private final double mutationRate;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double mutationRate = DEFAULT_MUTATION_RATE;
|
||||
private RandomGenerator rng;
|
||||
|
||||
/**
|
||||
* Each gene will have this probability of being mutated.
|
||||
*
|
||||
* @param rate The mutation rate. (default 0.005)
|
||||
*/
|
||||
public Builder mutationRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.mutationRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RandomMutationOperator build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
return new RandomMutationOperator(this);
|
||||
}
|
||||
}
|
||||
|
||||
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
|
||||
this.mutationRate = builder.mutationRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
|
||||
*
|
||||
* @param genes The genes to be mutated
|
||||
* @return True if the genes were mutated, otherwise false.
|
||||
*/
|
||||
@Override
|
||||
public boolean mutate(double[] genes) {
|
||||
boolean hasMutation = false;
|
||||
|
||||
for (int i = 0; i < genes.length; ++i) {
|
||||
if (rng.nextDouble() < mutationRate) {
|
||||
genes[i] = rng.nextDouble();
|
||||
hasMutation = true;
|
||||
}
|
||||
}
|
||||
|
||||
return hasMutation;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A population initializer that build an empty population.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class EmptyPopulationInitializer implements PopulationInitializer {
|
||||
|
||||
/**
|
||||
* Initialize an empty population
|
||||
*
|
||||
* @param size The maximum size of the population.
|
||||
* @return The initialized population.
|
||||
*/
|
||||
@Override
|
||||
public List<Chromosome> getInitializedPopulation(int size) {
|
||||
return new ArrayList<>(size);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An initializer that construct the population used by the population model.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface PopulationInitializer {
|
||||
/**
|
||||
* Called by the population model to construct the population
|
||||
*
|
||||
* @param size The maximum size of the population
|
||||
* @return An initialized population
|
||||
*/
|
||||
List<Chromosome> getInitializedPopulation(int size);
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A listener that is called when the population changes.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface PopulationListener {
|
||||
/**
|
||||
* Called after the population has changed.
|
||||
*
|
||||
* @param population The population after it has changed.
|
||||
*/
|
||||
void onChanged(List<Chromosome> population);
|
||||
}
|
|
@ -0,0 +1,182 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The population model handles all aspects of the population (initialization, additions and culling)
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class PopulationModel {
|
||||
private static final int DEFAULT_POPULATION_SIZE = 30;
|
||||
|
||||
private final CullOperator cullOperator;
|
||||
private final List<PopulationListener> populationListeners = new ArrayList<>();
|
||||
private Comparator<Chromosome> chromosomeComparator;
|
||||
|
||||
/**
|
||||
* The maximum population size
|
||||
*/
|
||||
@Getter
|
||||
private final int populationSize;
|
||||
|
||||
/**
|
||||
* The population
|
||||
*/
|
||||
@Getter
|
||||
public final List<Chromosome> population;
|
||||
|
||||
/**
|
||||
* A comparator used when higher fitness value is better
|
||||
*/
|
||||
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
|
||||
@Override
|
||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||
return -Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A comparator used when lower fitness value is better
|
||||
*/
|
||||
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
|
||||
@Override
|
||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||
return Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private int populationSize = DEFAULT_POPULATION_SIZE;
|
||||
private PopulationInitializer populationInitializer;
|
||||
private CullOperator cullOperator;
|
||||
|
||||
/**
|
||||
* Use an alternate population initialization behavior. Default is empty population.
|
||||
*
|
||||
* @param populationInitializer An instance of PopulationInitializer
|
||||
*/
|
||||
public Builder populationInitializer(PopulationInitializer populationInitializer) {
|
||||
this.populationInitializer = populationInitializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The maximum population size. <br>
|
||||
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
|
||||
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
|
||||
*
|
||||
* @param size The maximum size of the population
|
||||
*/
|
||||
public Builder populationSize(int size) {
|
||||
populationSize = size;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use an alternate cull operator behavior. Default is least fit culling.
|
||||
*
|
||||
* @param cullOperator An instance of a CullOperator
|
||||
*/
|
||||
public Builder cullOperator(CullOperator cullOperator) {
|
||||
this.cullOperator = cullOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PopulationModel build() {
|
||||
if (cullOperator == null) {
|
||||
cullOperator = new LeastFitCullOperator();
|
||||
}
|
||||
|
||||
if (populationInitializer == null) {
|
||||
populationInitializer = new EmptyPopulationInitializer();
|
||||
}
|
||||
|
||||
return new PopulationModel(this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public PopulationModel(PopulationModel.Builder builder) {
|
||||
populationSize = builder.populationSize;
|
||||
population = new ArrayList<>(builder.populationSize);
|
||||
PopulationInitializer populationInitializer = builder.populationInitializer;
|
||||
|
||||
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
|
||||
population.clear();
|
||||
population.addAll(initializedPopulation);
|
||||
|
||||
cullOperator = builder.cullOperator;
|
||||
cullOperator.initializeInstance(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by the GeneticSearchCandidateGenerator
|
||||
*/
|
||||
public void initializeInstance(boolean minimizeScore) {
|
||||
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a PopulationListener to the list of change listeners
|
||||
* @param listener A PopulationListener instance
|
||||
*/
|
||||
public void addListener(PopulationListener listener) {
|
||||
populationListeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
|
||||
*
|
||||
* @param element The chromosome to be added
|
||||
*/
|
||||
public void add(Chromosome element) {
|
||||
if (population.size() == populationSize) {
|
||||
cullOperator.cullPopulation();
|
||||
}
|
||||
|
||||
population.add(element);
|
||||
|
||||
Collections.sort(population, chromosomeComparator);
|
||||
|
||||
triggerPopulationChangedListeners(population);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Return false when the population is below the culled size, otherwise true. <br>
|
||||
* Used by the selection operator to know if the population is still too small and should generate random genes.
|
||||
*/
|
||||
public boolean isReadyToBreed() {
|
||||
return population.size() >= cullOperator.getCulledSize();
|
||||
}
|
||||
|
||||
private void triggerPopulationChangedListeners(List<Chromosome> population) {
|
||||
for (PopulationListener listener : populationListeners) {
|
||||
listener.onChanged(population);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
|
||||
* will start to generate offsprings of parents selected in the population.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class GeneticSelectionOperator extends SelectionOperator {
|
||||
|
||||
private final static int PREVIOUS_GENES_TO_KEEP = 100;
|
||||
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
|
||||
|
||||
private final CrossoverOperator crossoverOperator;
|
||||
private final MutationOperator mutationOperator;
|
||||
private final RandomGenerator rng;
|
||||
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
|
||||
private int previousGenesIdx = 0;
|
||||
|
||||
public static class Builder {
|
||||
private ChromosomeFactory chromosomeFactory;
|
||||
private PopulationModel populationModel;
|
||||
private CrossoverOperator crossoverOperator;
|
||||
private MutationOperator mutationOperator;
|
||||
private RandomGenerator rng;
|
||||
|
||||
/**
|
||||
* Use an alternate crossover behavior. Default is SinglePointCrossover.
|
||||
*
|
||||
* @param crossoverOperator An instance of CrossoverOperator
|
||||
*/
|
||||
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
|
||||
this.crossoverOperator = crossoverOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use an alternate mutation behavior. Default is RandomMutationOperator.
|
||||
*
|
||||
* @param mutationOperator An instance of MutationOperator
|
||||
*/
|
||||
public Builder mutationOperator(MutationOperator mutationOperator) {
|
||||
this.mutationOperator = mutationOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSelectionOperator build() {
|
||||
if (crossoverOperator == null) {
|
||||
crossoverOperator = new SinglePointCrossover.Builder().build();
|
||||
}
|
||||
|
||||
if (mutationOperator == null) {
|
||||
mutationOperator = new RandomMutationOperator.Builder().build();
|
||||
}
|
||||
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
|
||||
}
|
||||
}
|
||||
|
||||
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
|
||||
RandomGenerator rng) {
|
||||
this.crossoverOperator = crossoverOperator;
|
||||
this.mutationOperator = mutationOperator;
|
||||
this.rng = rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by GeneticSearchCandidateGenerator
|
||||
*/
|
||||
@Override
|
||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||
super.initializeInstance(populationModel, chromosomeFactory);
|
||||
crossoverOperator.initializeInstance(populationModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new set of genes. Has two distinct modes of operation
|
||||
* <ul>
|
||||
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
|
||||
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
|
||||
* </ul>
|
||||
* @return Returns the generated set of genes
|
||||
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
|
||||
* or if the crossover and the mutation operators can't generate a set,
|
||||
* this exception is thrown.
|
||||
*/
|
||||
@Override
|
||||
public double[] buildNextGenes() {
|
||||
double[] result;
|
||||
|
||||
boolean hasAlreadyBeenTried;
|
||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||
do {
|
||||
if (populationModel.isReadyToBreed()) {
|
||||
result = buildOffspring();
|
||||
} else {
|
||||
result = buildRandomGenes();
|
||||
}
|
||||
|
||||
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
|
||||
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
|
||||
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
|
||||
}
|
||||
} while (hasAlreadyBeenTried);
|
||||
|
||||
previousGenes[previousGenesIdx] = result;
|
||||
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private boolean hasAlreadyBeenTried(double[] genes) {
|
||||
for (int i = 0; i < previousGenes.length; ++i) {
|
||||
double[] current = previousGenes[i];
|
||||
if (current != null && Arrays.equals(current, genes)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private double[] buildOffspring() {
|
||||
double[] offspringValues;
|
||||
|
||||
boolean isModified;
|
||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||
do {
|
||||
CrossoverResult crossoverResult = crossoverOperator.crossover();
|
||||
offspringValues = crossoverResult.getGenes();
|
||||
isModified = crossoverResult.isModified();
|
||||
isModified |= mutationOperator.mutate(offspringValues);
|
||||
|
||||
if (!isModified && --attemptsRemaining == 0) {
|
||||
throw new GeneticGenerationException(
|
||||
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
|
||||
MAX_NUM_GENERATION_ATTEMPTS));
|
||||
}
|
||||
} while (!isModified);
|
||||
|
||||
return offspringValues;
|
||||
}
|
||||
|
||||
private double[] buildRandomGenes() {
|
||||
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
|
||||
for (int i = 0; i < randomValues.length; ++i) {
|
||||
randomValues[i] = rng.nextDouble();
|
||||
}
|
||||
|
||||
return randomValues;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class SelectionOperator {
|
||||
protected PopulationModel populationModel;
|
||||
protected ChromosomeFactory chromosomeFactory;
|
||||
|
||||
/**
|
||||
* Called by GeneticSearchCandidateGenerator
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||
|
||||
this.populationModel = populationModel;
|
||||
this.chromosomeFactory = chromosomeFactory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new set of genes.
|
||||
*/
|
||||
public abstract double[] buildNextGenes();
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.util;
|
||||
|
||||
import org.nd4j.common.function.Supplier;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
/**
 * A {@link Supplier} that captures an object as Java-serialized bytes at construction time and deserializes those
 * bytes again on every call to {@link #get()}.
 *
 * @param <T> Type of the supplied object; the instance passed to the constructor must be serializable
 */
public class SerializedSupplier<T> implements Serializable, Supplier<T> {

    // Deliberately left non-final and without an explicit serialVersionUID: either change would alter the
    // serialized-form identity of this class for instances that were already written out.
    private byte[] asBytes;

    /**
     * Serialize the given object into an in-memory byte array.
     *
     * @param obj The object to capture
     * @throws RuntimeException If the object cannot be serialized; the original exception is set as the cause
     */
    public SerializedSupplier(T obj) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                        ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
            // Push any buffered data into baos before taking the snapshot; the try-with-resources statement
            // closes both streams, so no explicit close() is needed.
            oos.flush();
            asBytes = baos.toByteArray();
        } catch (Exception e) {
            throw new RuntimeException("Error serializing object - must be serializable", e);
        }
    }

    /**
     * Deserialize the captured bytes.
     *
     * @return The object read back from the captured bytes
     * @throws RuntimeException If deserialization fails; the original exception is set as the cause
     */
    @Override
    @SuppressWarnings("unchecked") // the bytes were produced from a T in the constructor
    public T get() {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))) {
            return (T) ois.readObject();
        } catch (Exception e) {
            throw new RuntimeException("Error deserializing object", e);
        }
    }
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue