Update copyrights remove attic and relocate elsewhere
parent
5bd386a4f9
commit
46dbd0b203
|
@ -1,45 +0,0 @@
|
|||
# Arbiter
|
||||
|
||||
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
|
||||
|
||||
|
||||
## Modules
|
||||
Arbiter contains the following modules:
|
||||
|
||||
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
|
||||
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
|
||||
|
||||
|
||||
## Hyperparameter Optimization Functionality
|
||||
|
||||
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
|
||||
|
||||
- Grid search
|
||||
- Random search
|
||||
|
||||
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes
|
||||
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
|
||||
|
||||
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
|
||||
|
||||
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
|
||||
|
||||
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
|
||||
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate models configurations of some type ```C```. The following implementations are defined in arbiter-core:
|
||||
- ```RandomSearchCandidateGenerator```
|
||||
- ```GridSearchCandidateGenerator```
|
||||
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
|
||||
- A key concept here is that they score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
|
||||
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
|
||||
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
|
||||
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
|
||||
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
|
||||
- Note that ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
|
||||
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
|
||||
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
|
||||
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
|
||||
|
||||
|
||||
### Optimization of DeepLearning4J Models
|
||||
|
||||
(This section: forthcoming)
|
|
@ -1,89 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>arbiter</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>arbiter-core</artifactId>
|
||||
|
||||
<name>arbiter-core</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>com.google.code.findbugs</groupId>
|
||||
<artifactId>*</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-math3</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>joda-time</groupId>
|
||||
<artifactId>joda-time</artifactId>
|
||||
</dependency>
|
||||
<!-- ND4J Shaded Jackson Dependency -->
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>jackson</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-codec</groupId>
|
||||
<artifactId>commons-codec</artifactId>
|
||||
<version>${commons-codec.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,95 +0,0 @@
|
|||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<assembly>
|
||||
<id>bin</id>
|
||||
<!-- START SNIPPET: formats -->
|
||||
<formats>
|
||||
<format>tar.gz</format>
|
||||
<!--
|
||||
<format>tar.bz2</format>
|
||||
<format>zip</format>
|
||||
-->
|
||||
</formats>
|
||||
<!-- END SNIPPET: formats -->
|
||||
|
||||
<dependencySets>
|
||||
<dependencySet>
|
||||
<outputDirectory>lib</outputDirectory>
|
||||
<includes>
|
||||
<include>*:jar:*</include>
|
||||
</includes>
|
||||
<excludes>
|
||||
<exclude>*:sources</exclude>
|
||||
</excludes>
|
||||
</dependencySet>
|
||||
</dependencySets>
|
||||
|
||||
<!-- START SNIPPET: fileSets -->
|
||||
<fileSets>
|
||||
<fileSet>
|
||||
<includes>
|
||||
<include>readme.txt</include>
|
||||
</includes>
|
||||
</fileSet>
|
||||
|
||||
<fileSet>
|
||||
<directory>src/main/resources/bin/</directory>
|
||||
<outputDirectory>bin</outputDirectory>
|
||||
<includes>
|
||||
<include>arbiter</include>
|
||||
</includes>
|
||||
<lineEnding>unix</lineEnding>
|
||||
<fileMode>0755</fileMode>
|
||||
</fileSet>
|
||||
|
||||
<fileSet>
|
||||
<directory>examples</directory>
|
||||
<outputDirectory>examples</outputDirectory>
|
||||
<!--
|
||||
<lineEnding>unix</lineEnding>
|
||||
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
||||
-->
|
||||
</fileSet>
|
||||
|
||||
|
||||
<!--
|
||||
<fileSet>
|
||||
<directory>src/bin</directory>
|
||||
<outputDirectory>bin</outputDirectory>
|
||||
<includes>
|
||||
<include>hello</include>
|
||||
</includes>
|
||||
<lineEnding>unix</lineEnding>
|
||||
<fileMode>0755</fileMode>
|
||||
</fileSet>
|
||||
-->
|
||||
|
||||
<fileSet>
|
||||
<directory>target</directory>
|
||||
<outputDirectory>./</outputDirectory>
|
||||
<includes>
|
||||
<include>*.jar</include>
|
||||
</includes>
|
||||
</fileSet>
|
||||
|
||||
</fileSets>
|
||||
<!-- END SNIPPET: fileSets -->
|
||||
</assembly>
|
|
@ -1,76 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Created by Alex on 23/07/2017.
|
||||
*/
|
||||
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
Map<String, ParameterSpace> m = new LinkedHashMap<>();
|
||||
|
||||
//Need to manually build and walk the class hierarchy...
|
||||
Class<?> currClass = this.getClass();
|
||||
List<Class<?>> classHierarchy = new ArrayList<>();
|
||||
while (currClass != Object.class) {
|
||||
classHierarchy.add(currClass);
|
||||
currClass = currClass.getSuperclass();
|
||||
}
|
||||
|
||||
for (int i = classHierarchy.size() - 1; i >= 0; i--) {
|
||||
//Use reflection here to avoid a mass of boilerplate code...
|
||||
Field[] allFields = classHierarchy.get(i).getDeclaredFields();
|
||||
|
||||
for (Field f : allFields) {
|
||||
|
||||
String name = f.getName();
|
||||
Class<?> fieldClass = f.getType();
|
||||
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
|
||||
|
||||
if (!isParamSpacefield) {
|
||||
continue;
|
||||
}
|
||||
|
||||
f.setAccessible(true);
|
||||
|
||||
ParameterSpace<?> p;
|
||||
try {
|
||||
p = (ParameterSpace<?>) f.get(this);
|
||||
} catch (IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (p != null) {
|
||||
m.put(name, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
|
||||
import org.nd4j.common.function.Supplier;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Candidate: a proposed hyperparameter configuration.
|
||||
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
|
||||
*/
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class Candidate<C> implements Serializable {
|
||||
|
||||
private Supplier<C> supplier;
|
||||
private int index;
|
||||
private double[] flatParameters;
|
||||
private Map<String, Object> dataParameters;
|
||||
private Exception exception;
|
||||
|
||||
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
|
||||
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
|
||||
}
|
||||
|
||||
public Candidate(C value, int index, double[] flatParameters) {
|
||||
this(new SerializedSupplier<C>(value), index, flatParameters);
|
||||
}
|
||||
|
||||
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
|
||||
this(value, index, flatParameters, null, null);
|
||||
}
|
||||
|
||||
public C getValue(){
|
||||
return supplier.get();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
/**
|
||||
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
|
||||
* This abstraction allows for different ways of generating the next configuration to test; for example,
|
||||
* random search, grid search, Bayesian optimization methods, etc.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface CandidateGenerator {
|
||||
|
||||
/**
|
||||
* Is this candidate generator able to generate more candidates? This will always return true in some
|
||||
* cases, but some search strategies have a limit (grid search, for example)
|
||||
*/
|
||||
boolean hasMoreCandidates();
|
||||
|
||||
/**
|
||||
* Generate a candidate hyperparameter configuration
|
||||
*/
|
||||
Candidate getCandidate();
|
||||
|
||||
/**
|
||||
* Report results for the candidate generator.
|
||||
*
|
||||
* @param result The results to report
|
||||
*/
|
||||
void reportResults(OptimizationResult result);
|
||||
|
||||
/**
|
||||
* @return Get the parameter space for this candidate generator
|
||||
*/
|
||||
ParameterSpace<?> getParameterSpace();
|
||||
|
||||
/**
|
||||
* @param rngSeed Set the random number generator seed for the candidate generator
|
||||
*/
|
||||
void setRngSeed(long rngSeed);
|
||||
|
||||
/**
|
||||
* @return The type (class) of the generated candidates
|
||||
*/
|
||||
Class<?> getCandidateType();
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* An optimization result represents the results of an optimization run, including the candidate configuration, the
|
||||
* trained model, the score for that model, and index of the model
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@JsonIgnoreProperties({"resultReference"})
|
||||
public class OptimizationResult implements Serializable {
|
||||
@JsonProperty
|
||||
private Candidate candidate;
|
||||
@JsonProperty
|
||||
private Double score;
|
||||
@JsonProperty
|
||||
private int index;
|
||||
@JsonProperty
|
||||
private Object modelSpecificResults;
|
||||
@JsonProperty
|
||||
private CandidateInfo candidateInfo;
|
||||
private ResultReference resultReference;
|
||||
|
||||
|
||||
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
|
||||
CandidateInfo candidateInfo, ResultReference resultReference) {
|
||||
this.candidate = candidate;
|
||||
this.score = score;
|
||||
this.index = index;
|
||||
this.modelSpecificResults = modelSpecificResults;
|
||||
this.candidateInfo = candidateInfo;
|
||||
this.resultReference = resultReference;
|
||||
}
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnore;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
|
||||
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
|
||||
* multiple nested ParameterSpaces
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ParameterSpace<P> {
|
||||
|
||||
/**
|
||||
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
|
||||
* mapping function (such as the prior probability distribution)
|
||||
*
|
||||
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
|
||||
*/
|
||||
P getValue(double[] parameterValues);
|
||||
|
||||
/**
|
||||
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
|
||||
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
|
||||
*
|
||||
* @return Number of hyperparameters to be optimized
|
||||
*/
|
||||
int numParameters();
|
||||
|
||||
/**
|
||||
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
|
||||
* nested parameter spaces
|
||||
*/
|
||||
List<ParameterSpace> collectLeaves();
|
||||
|
||||
/**
|
||||
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
|
||||
* nested parameter spaces. The map should be empty for leaf parameter spaces
|
||||
*
|
||||
* @return A map of nested parameter spaces
|
||||
*/
|
||||
Map<String, ParameterSpace> getNestedSpaces();
|
||||
|
||||
/**
|
||||
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
|
||||
*/
|
||||
@JsonIgnore
|
||||
boolean isLeaf();
|
||||
|
||||
/**
|
||||
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
|
||||
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
|
||||
*
|
||||
* @param indices Indices to set. Length should equal {@link #numParameters()}
|
||||
*/
|
||||
void setIndices(int... indices);
|
||||
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
/**
|
||||
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
|
||||
* that can be executed as a Callable
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface TaskCreator {
|
||||
|
||||
/**
|
||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||
*
|
||||
* @param candidate Candidate (model) configuration to be trained
|
||||
* @param dataProvider DataProvider, for the data
|
||||
* @param scoreFunction Score function to be used to evaluate the model
|
||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||
*/
|
||||
@Deprecated
|
||||
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
|
||||
List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||
|
||||
/**
|
||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
||||
*
|
||||
* @param candidate Candidate (model) configuration to be trained
|
||||
* @param dataSource Data source
|
||||
* @param dataSourceProperties Properties (may be null) for the data source
|
||||
* @param scoreFunction Score function to be used to evaluate the model
|
||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
||||
*/
|
||||
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
|
||||
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class TaskCreatorProvider {
|
||||
|
||||
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
|
||||
|
||||
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
|
||||
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
|
||||
try {
|
||||
if(c == null){
|
||||
return null;
|
||||
}
|
||||
return c.newInstance();
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
|
||||
Class<? extends TaskCreator> creatorClass){
|
||||
map.put(spaceClass, creatorClass);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.adapter;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
|
||||
*
|
||||
* @param <F> Type to convert from
|
||||
* @param <T> Type to convert to
|
||||
* @author Alex Black
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
||||
|
||||
|
||||
protected abstract T convertValue(F from);
|
||||
|
||||
protected abstract ParameterSpace<F> underlying();
|
||||
|
||||
protected abstract String underlyingName();
|
||||
|
||||
|
||||
@Override
|
||||
public T getValue(double[] parameterValues) {
|
||||
return convertValue(underlying().getValue(parameterValues));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return underlying().numParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
ParameterSpace p = underlying();
|
||||
if(p.isLeaf()){
|
||||
return Collections.singletonList(p);
|
||||
}
|
||||
return underlying().collectLeaves();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false; //Underlying may be a leaf, however
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
underlying().setIndices(indices);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return underlying().toString();
|
||||
}
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* DataProvider interface abstracts out the providing of data
|
||||
* @deprecated Use {@link DataSource}
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@Deprecated
|
||||
public interface DataProvider extends Serializable {
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data.
|
||||
* Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
Object trainData(Map<String, Object> dataParameters);
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
Object testData(Map<String, Object> dataParameters);
|
||||
|
||||
Class<?> getDataType();
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* This is a {@link DataProvider} for
|
||||
* an {@link DataSetIteratorFactory} which
|
||||
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
|
||||
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
|
||||
* for use with arbiter.
|
||||
*
|
||||
* This {@link DataProvider} is mainly meant for use for command line driven
|
||||
* applications.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Data
|
||||
public class DataSetIteratorFactoryProvider implements DataProvider {
|
||||
|
||||
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data.
|
||||
* Data parameters map is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
@Override
|
||||
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
|
||||
return create(dataParameters);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get training data given some parameters for the data. Data parameters map
|
||||
* is used to specify things like batch
|
||||
* size data preprocessing
|
||||
*
|
||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
||||
* @return training data
|
||||
*/
|
||||
@Override
|
||||
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
|
||||
return create(dataParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getDataType() {
|
||||
return DataSetIteratorFactory.class;
|
||||
}
|
||||
|
||||
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
|
||||
if (dataParameters == null)
|
||||
throw new IllegalArgumentException(
|
||||
"Data parameters is null. Please specify a class name to create a dataset iterator.");
|
||||
if (!dataParameters.containsKey(FACTORY_KEY))
|
||||
throw new IllegalArgumentException(
|
||||
"No data set iterator factory class found. Please specify a class name with key "
|
||||
+ FACTORY_KEY);
|
||||
String value = dataParameters.get(FACTORY_KEY).toString();
|
||||
try {
|
||||
Class<? extends DataSetIteratorFactory> clazz =
|
||||
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
||||
return clazz.newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* DataSource: defines where the data should come from for training and testing.
|
||||
* Note that implementations must have a no-argument contsructor
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface DataSource extends Serializable {
|
||||
|
||||
/**
|
||||
* Configure the current data source with the specified properties
|
||||
* Note: These properties are fixed for the training instance, and are optionally provided by the user
|
||||
* at the configuration stage.
|
||||
* The properties could be anything - and are usually specific to each DataSource implementation.
|
||||
* For example, values such as batch size could be set using these properties
|
||||
* @param properties Properties to apply to the data source instance
|
||||
*/
|
||||
void configure(Properties properties);
|
||||
|
||||
/**
|
||||
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
||||
*/
|
||||
Object trainData();
|
||||
|
||||
/**
|
||||
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
||||
*/
|
||||
Object testData();
|
||||
|
||||
/**
|
||||
* The type of data returned by {@link #trainData()} and {@link #testData()}.
|
||||
* Usually DataSetIterator or MultiDataSetIterator
|
||||
* @return Class of the objects returned by trainData and testData
|
||||
*/
|
||||
Class<?> getDataType();
|
||||
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.evaluation;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* ModelEvaluator: Used to conduct additional evaluation.
|
||||
* For example, this may be classification performance on a test set or similar
|
||||
*/
|
||||
public interface ModelEvaluator extends Serializable {
|
||||
Object evaluateModel(Object model, DataProvider dataProvider);
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
/**
|
||||
* @return The datatypes supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedDataTypes();
|
||||
}
|
|
@ -1,65 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple class to store optimization results in-memory.
|
||||
* Not recommended for large (or a large number of) models.
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class InMemoryResultSaver implements ResultSaver {
|
||||
@Override
|
||||
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
|
||||
return new InMemoryResult(result, modelResult);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedCandidateTypes() {
|
||||
return Collections.<Class<?>>singletonList(Object.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedModelTypes() {
|
||||
return Collections.<Class<?>>singletonList(Object.class);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
private static class InMemoryResult implements ResultReference {
|
||||
private OptimizationResult result;
|
||||
private Object modelResult;
|
||||
|
||||
@Override
|
||||
public OptimizationResult getResult() throws IOException {
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getResultModel() throws IOException {
|
||||
return modelResult;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
|
||||
* parameters each)
|
||||
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
|
||||
* and we can easily load it back into memory (if/when required) using the getResult() method
|
||||
*/
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ResultReference {
|
||||
|
||||
OptimizationResult getResult() throws IOException;
|
||||
|
||||
Object getResultModel() throws IOException;
|
||||
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
|
||||
* regardless of where/how they are saved.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ResultSaver {
|
||||
|
||||
/**
|
||||
* Save the model (including configuration and any additional evaluation/results)
|
||||
*
|
||||
* @param result Optimization result for the model to save
|
||||
* @param modelResult Model result to save
|
||||
* @return ResultReference, such that the result can be loaded back into memory
|
||||
* @throws IOException If IO error occurs during model saving
|
||||
*/
|
||||
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
|
||||
|
||||
/**
|
||||
* @return The candidate types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedCandidateTypes();
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.score;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* ScoreFunction defines the objective of hyperparameter optimization.
|
||||
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
|
||||
* in the configuration.
|
||||
*
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ScoreFunction extends Serializable {
|
||||
|
||||
/**
|
||||
* Calculate and return the score, for the given model and data provider
|
||||
*
|
||||
* @param model Model to score
|
||||
* @param dataProvider Data provider - data to use
|
||||
* @param dataParameters Parameters for data
|
||||
* @return Calculated score
|
||||
*/
|
||||
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
|
||||
|
||||
/**
|
||||
* Calculate and return the score, for the given model and data provider
|
||||
*
|
||||
* @param model Model to score
|
||||
* @param dataSource Data source
|
||||
* @param dataSourceProperties data source properties
|
||||
* @return Calculated score
|
||||
*/
|
||||
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
|
||||
|
||||
/**
|
||||
* Should this score function be minimized or maximized?
|
||||
*
|
||||
* @return true if score should be minimized, false if score should be maximized
|
||||
*/
|
||||
boolean minimize();
|
||||
|
||||
/**
|
||||
* @return The model types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedModelTypes();
|
||||
|
||||
/**
|
||||
* @return The data types supported by this class
|
||||
*/
|
||||
List<Class<?>> getSupportedDataTypes();
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
/**
|
||||
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
|
||||
* Note that this is counted as number of completed candidates, plus number of failed candidates.
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class MaxCandidatesCondition implements TerminationCondition {
|
||||
@JsonProperty
|
||||
private int maxCandidates;
|
||||
|
||||
@Override
|
||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MaxCandidatesCondition(" + maxCandidates + ")";
|
||||
}
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.joda.time.format.DateTimeFormat;
|
||||
import org.joda.time.format.DateTimeFormatter;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Terminate hyperparameter optimization after
|
||||
* a fixed amount of time has passed
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class MaxTimeCondition implements TerminationCondition {
|
||||
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
|
||||
|
||||
private long duration;
|
||||
private TimeUnit timeUnit;
|
||||
private long startTime;
|
||||
private long endTime;
|
||||
|
||||
|
||||
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
|
||||
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
|
||||
this.duration = duration;
|
||||
this.timeUnit = timeUnit;
|
||||
this.startTime = startTime;
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param duration Duration of time
|
||||
* @param timeUnit Unit that the duration is specified in
|
||||
*/
|
||||
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
|
||||
this.duration = duration;
|
||||
this.timeUnit = timeUnit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
||||
startTime = System.currentTimeMillis();
|
||||
this.endTime = startTime + timeUnit.toMillis(duration);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||
return System.currentTimeMillis() >= endTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (startTime > 0) {
|
||||
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
|
||||
+ "\",end=\"" + formatter.print(endTime) + "\")";
|
||||
} else {
|
||||
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
||||
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
/**
|
||||
* Global termination condition for conducting hyperparameter optimization.
|
||||
* Termination conditions are used to determine if/when the optimization should stop.
|
||||
*/
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public interface TerminationCondition {
|
||||
|
||||
/**
|
||||
* Initialize the termination condition (such as starting timers, etc).
|
||||
*/
|
||||
void initialize(IOptimizationRunner optimizationRunner);
|
||||
|
||||
/**
|
||||
* Determine whether optimization should be terminated
|
||||
*
|
||||
* @param optimizationRunner Optimization runner
|
||||
* @return true if learning should be terminated, false otherwise
|
||||
*/
|
||||
boolean terminate(IOptimizationRunner optimizationRunner);
|
||||
|
||||
}
|
|
@ -1,223 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.config;
|
||||
|
||||
import lombok.*;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* OptimizationConfiguration ties together all of the various
|
||||
* components (such as data, score functions, result saving etc)
|
||||
* required to execute hyperparameter optimization.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public class OptimizationConfiguration {
|
||||
@JsonSerialize
|
||||
private DataProvider dataProvider;
|
||||
@JsonSerialize
|
||||
private Class<? extends DataSource> dataSource;
|
||||
@JsonSerialize
|
||||
private Properties dataSourceProperties;
|
||||
@JsonSerialize
|
||||
private CandidateGenerator candidateGenerator;
|
||||
@JsonSerialize
|
||||
private ResultSaver resultSaver;
|
||||
@JsonSerialize
|
||||
private ScoreFunction scoreFunction;
|
||||
@JsonSerialize
|
||||
private List<TerminationCondition> terminationConditions;
|
||||
@JsonSerialize
|
||||
private Long rngSeed;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private long executionStartTime;
|
||||
|
||||
|
||||
private OptimizationConfiguration(Builder builder) {
|
||||
this.dataProvider = builder.dataProvider;
|
||||
this.dataSource = builder.dataSource;
|
||||
this.dataSourceProperties = builder.dataSourceProperties;
|
||||
this.candidateGenerator = builder.candidateGenerator;
|
||||
this.resultSaver = builder.resultSaver;
|
||||
this.scoreFunction = builder.scoreFunction;
|
||||
this.terminationConditions = builder.terminationConditions;
|
||||
this.rngSeed = builder.rngSeed;
|
||||
|
||||
if (rngSeed != null)
|
||||
candidateGenerator.setRngSeed(rngSeed);
|
||||
|
||||
//Validate the configuration: data types, score types, etc
|
||||
//TODO
|
||||
|
||||
//Validate that the dataSource has a no-arg constructor
|
||||
if(dataSource != null){
|
||||
try{
|
||||
dataSource.getConstructor();
|
||||
} catch (NoSuchMethodException e){
|
||||
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private DataProvider dataProvider;
|
||||
private Class<? extends DataSource> dataSource;
|
||||
private Properties dataSourceProperties;
|
||||
private CandidateGenerator candidateGenerator;
|
||||
private ResultSaver resultSaver;
|
||||
private ScoreFunction scoreFunction;
|
||||
private List<TerminationCondition> terminationConditions;
|
||||
private Long rngSeed;
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #dataSource(Class, Properties)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder dataProvider(DataProvider dataProvider) {
|
||||
this.dataProvider = dataProvider;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* DataSource: defines where the data should come from for training and testing.
|
||||
* Note that implementations must have a no-argument constructor
|
||||
* @param dataSource Class for the data source
|
||||
* @param dataSourceProperties May be null. Properties for configuring the data source
|
||||
*/
|
||||
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties){
|
||||
this.dataSource = dataSource;
|
||||
this.dataSourceProperties = dataSourceProperties;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
|
||||
this.candidateGenerator = candidateGenerator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder modelSaver(ResultSaver resultSaver) {
|
||||
this.resultSaver = resultSaver;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder scoreFunction(ScoreFunction scoreFunction) {
|
||||
this.scoreFunction = scoreFunction;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Termination conditions to use
|
||||
* @param conditions
|
||||
* @return
|
||||
*/
|
||||
public Builder terminationConditions(TerminationCondition... conditions) {
|
||||
terminationConditions = Arrays.asList(conditions);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
|
||||
this.terminationConditions = terminationConditions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder rngSeed(long rngSeed) {
|
||||
this.rngSeed = rngSeed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public OptimizationConfiguration build() {
|
||||
return new OptimizationConfiguration(this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create an optimization configuration from the json
|
||||
* @param json the json to create the config from
|
||||
* For type definitions
|
||||
* @see OptimizationConfiguration
|
||||
*/
|
||||
public static OptimizationConfiguration fromYaml(String json) {
|
||||
try {
|
||||
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an optimization configuration from the json
|
||||
* @param json the json to create the config from
|
||||
* @see OptimizationConfiguration
|
||||
*/
|
||||
public static OptimizationConfiguration fromJson(String json) {
|
||||
try {
|
||||
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a json configuration of this optimization configuration
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public String toJson() {
|
||||
try {
|
||||
return JsonMapper.getMapper().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a yaml configuration of this optimization configuration
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public String toYaml() {
|
||||
try {
|
||||
return JsonMapper.getYamlMapper().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
|
||||
*/
|
||||
public class DegenerateIntegerDistribution implements IntegerDistribution {
|
||||
private int value;
|
||||
|
||||
public DegenerateIntegerDistribution(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public double probability(int x) {
|
||||
return (x == value ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(int x) {
|
||||
return (x >= value ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
|
||||
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalMean() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalVariance() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSupportLowerBound() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSupportUpperBound() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
//no op
|
||||
}
|
||||
|
||||
@Override
|
||||
public int sample() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] sample(int sampleSize) {
|
||||
int[] out = new int[sampleSize];
|
||||
Arrays.fill(out, value);
|
||||
return out;
|
||||
}
|
||||
}
|
|
@ -1,151 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
|
||||
/**
|
||||
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
|
||||
* don't implement serializable etc.
|
||||
* Which makes unit testing etc quite difficult.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class DistributionUtils {
|
||||
|
||||
private DistributionUtils() {}
|
||||
|
||||
|
||||
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
|
||||
if (a.getClass() != b.getClass())
|
||||
return false;
|
||||
Class<?> c = a.getClass();
|
||||
if (c == BetaDistribution.class) {
|
||||
BetaDistribution ba = (BetaDistribution) a;
|
||||
BetaDistribution bb = (BetaDistribution) b;
|
||||
|
||||
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
|
||||
} else if (c == CauchyDistribution.class) {
|
||||
CauchyDistribution ca = (CauchyDistribution) a;
|
||||
CauchyDistribution cb = (CauchyDistribution) b;
|
||||
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
|
||||
} else if (c == ChiSquaredDistribution.class) {
|
||||
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
|
||||
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
|
||||
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
|
||||
} else if (c == ExponentialDistribution.class) {
|
||||
ExponentialDistribution ea = (ExponentialDistribution) a;
|
||||
ExponentialDistribution eb = (ExponentialDistribution) b;
|
||||
return ea.getMean() == eb.getMean();
|
||||
} else if (c == FDistribution.class) {
|
||||
FDistribution fa = (FDistribution) a;
|
||||
FDistribution fb = (FDistribution) b;
|
||||
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
|
||||
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
|
||||
} else if (c == GammaDistribution.class) {
|
||||
GammaDistribution ga = (GammaDistribution) a;
|
||||
GammaDistribution gb = (GammaDistribution) b;
|
||||
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
|
||||
} else if (c == LevyDistribution.class) {
|
||||
LevyDistribution la = (LevyDistribution) a;
|
||||
LevyDistribution lb = (LevyDistribution) b;
|
||||
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
|
||||
} else if (c == LogNormalDistribution.class) {
|
||||
LogNormalDistribution la = (LogNormalDistribution) a;
|
||||
LogNormalDistribution lb = (LogNormalDistribution) b;
|
||||
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
|
||||
} else if (c == NormalDistribution.class) {
|
||||
NormalDistribution na = (NormalDistribution) a;
|
||||
NormalDistribution nb = (NormalDistribution) b;
|
||||
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
|
||||
} else if (c == ParetoDistribution.class) {
|
||||
ParetoDistribution pa = (ParetoDistribution) a;
|
||||
ParetoDistribution pb = (ParetoDistribution) b;
|
||||
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
|
||||
} else if (c == TDistribution.class) {
|
||||
TDistribution ta = (TDistribution) a;
|
||||
TDistribution tb = (TDistribution) b;
|
||||
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
|
||||
} else if (c == TriangularDistribution.class) {
|
||||
TriangularDistribution ta = (TriangularDistribution) a;
|
||||
TriangularDistribution tb = (TriangularDistribution) b;
|
||||
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
|
||||
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
|
||||
} else if (c == UniformRealDistribution.class) {
|
||||
UniformRealDistribution ua = (UniformRealDistribution) a;
|
||||
UniformRealDistribution ub = (UniformRealDistribution) b;
|
||||
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
|
||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||
} else if (c == WeibullDistribution.class) {
|
||||
WeibullDistribution wa = (WeibullDistribution) a;
|
||||
WeibullDistribution wb = (WeibullDistribution) b;
|
||||
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
|
||||
} else if (c == LogUniformDistribution.class ){
|
||||
LogUniformDistribution lu_a = (LogUniformDistribution)a;
|
||||
LogUniformDistribution lu_b = (LogUniformDistribution)b;
|
||||
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
|
||||
if (a.getClass() != b.getClass())
|
||||
return false;
|
||||
Class<?> c = a.getClass();
|
||||
|
||||
if (c == BinomialDistribution.class) {
|
||||
BinomialDistribution ba = (BinomialDistribution) a;
|
||||
BinomialDistribution bb = (BinomialDistribution) b;
|
||||
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
|
||||
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
|
||||
} else if (c == GeometricDistribution.class) {
|
||||
GeometricDistribution ga = (GeometricDistribution) a;
|
||||
GeometricDistribution gb = (GeometricDistribution) b;
|
||||
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
|
||||
} else if (c == HypergeometricDistribution.class) {
|
||||
HypergeometricDistribution ha = (HypergeometricDistribution) a;
|
||||
HypergeometricDistribution hb = (HypergeometricDistribution) b;
|
||||
return ha.getPopulationSize() == hb.getPopulationSize()
|
||||
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
|
||||
&& ha.getSampleSize() == hb.getSampleSize();
|
||||
} else if (c == PascalDistribution.class) {
|
||||
PascalDistribution pa = (PascalDistribution) a;
|
||||
PascalDistribution pb = (PascalDistribution) b;
|
||||
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
|
||||
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
|
||||
} else if (c == PoissonDistribution.class) {
|
||||
PoissonDistribution pa = (PoissonDistribution) a;
|
||||
PoissonDistribution pb = (PoissonDistribution) b;
|
||||
return pa.getMean() == pb.getMean();
|
||||
} else if (c == UniformIntegerDistribution.class) {
|
||||
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
|
||||
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
|
||||
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
|
||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
||||
} else if (c == ZipfDistribution.class) {
|
||||
ZipfDistribution za = (ZipfDistribution) a;
|
||||
ZipfDistribution zb = (ZipfDistribution) b;
|
||||
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -1,157 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.nd4j.shade.guava.base.Preconditions;
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Log uniform distribution, with support in range [min, max] for min > 0
|
||||
*
|
||||
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class LogUniformDistribution implements RealDistribution {
|
||||
|
||||
@Getter private final double min;
|
||||
@Getter private final double max;
|
||||
|
||||
private final double logMin;
|
||||
private final double logMax;
|
||||
|
||||
private transient Random rng = new Random();
|
||||
|
||||
/**
|
||||
*
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value
|
||||
*/
|
||||
public LogUniformDistribution(double min, double max) {
|
||||
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
|
||||
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
|
||||
+ min + "," + max + ")");
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
|
||||
this.logMin = Math.log(min);
|
||||
this.logMax = Math.log(max);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double probability(double x) {
|
||||
if(x < min || x > max){
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1.0 / (x * (logMax - logMin));
|
||||
}
|
||||
|
||||
@Override
|
||||
public double density(double x) {
|
||||
return probability(x);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(double x) {
|
||||
if(x <= min){
|
||||
return 0.0;
|
||||
} else if(x >= max){
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
return (Math.log(x)-logMin)/(logMax-logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
|
||||
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
|
||||
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
|
||||
return Math.exp(p * (logMax-logMin) + logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalMean() {
|
||||
return (max-min)/(logMax-logMin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getNumericalVariance() {
|
||||
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
|
||||
return d1 / (2*Math.pow(logMax-logMin, 2.0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getSupportLowerBound() {
|
||||
return min;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getSupportUpperBound() {
|
||||
return max;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportLowerBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportUpperBoundInclusive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupportConnected() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reseedRandomGenerator(long seed) {
|
||||
rng.setSeed(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double sample() {
|
||||
return inverseCumulativeProbability(rng.nextDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] sample(int sampleSize) {
|
||||
double[] d = new double[sampleSize];
|
||||
for( int i=0; i<sampleSize; i++ ){
|
||||
d[i] = sample();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
|
||||
}
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
|
||||
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
|
||||
* are built.
|
||||
*
|
||||
* @param <T> Type of candidates to generate
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
|
||||
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
|
||||
protected ParameterSpace<T> parameterSpace;
|
||||
protected AtomicInteger candidateCounter = new AtomicInteger(0);
|
||||
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
protected Map<String, Object> dataParameters;
|
||||
protected boolean initDone = false;
|
||||
|
||||
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
|
||||
boolean initDone) {
|
||||
this.parameterSpace = parameterSpace;
|
||||
this.dataParameters = dataParameters;
|
||||
this.initDone = initDone;
|
||||
}
|
||||
|
||||
protected void initialize() {
|
||||
if(!initDone) {
|
||||
//First: collect leaf parameter spaces objects and remove duplicates
|
||||
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||
|
||||
//Second: assign each a number
|
||||
int i = 0;
|
||||
for (ParameterSpace ps : noDuplicatesList) {
|
||||
int np = ps.numParameters();
|
||||
if (np == 1) {
|
||||
ps.setIndices(i++);
|
||||
} else {
|
||||
int[] values = new int[np];
|
||||
for (int j = 0; j < np; j++)
|
||||
values[j] = i++;
|
||||
ps.setIndices(values);
|
||||
}
|
||||
}
|
||||
initDone = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParameterSpace<T> getParameterSpace() {
|
||||
return parameterSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reportResults(OptimizationResult result) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngSeed(long rngSeed) {
|
||||
rng.setSeed(rngSeed);
|
||||
}
|
||||
}
|
|
@ -1,189 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Uses a genetic algorithm to generate candidates.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Slf4j
|
||||
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||
|
||||
@Getter
|
||||
protected final PopulationModel populationModel;
|
||||
|
||||
protected final ChromosomeFactory chromosomeFactory;
|
||||
protected final SelectionOperator selectionOperator;
|
||||
|
||||
protected boolean hasMoreCandidates = true;
|
||||
|
||||
public static class Builder {
|
||||
protected final ParameterSpace<?> parameterSpace;
|
||||
|
||||
protected Map<String, Object> dataParameters;
|
||||
protected boolean initDone;
|
||||
protected boolean minimizeScore;
|
||||
protected PopulationModel populationModel;
|
||||
protected ChromosomeFactory chromosomeFactory;
|
||||
protected SelectionOperator selectionOperator;
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
|
||||
*/
|
||||
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
|
||||
this.parameterSpace = parameterSpace;
|
||||
this.minimizeScore = scoreFunction.minimize();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param populationModel The PopulationModel instance to use.
|
||||
*/
|
||||
public Builder populationModel(PopulationModel populationModel) {
|
||||
this.populationModel = populationModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
|
||||
*/
|
||||
public Builder selectionOperator(SelectionOperator selectionOperator) {
|
||||
this.selectionOperator = selectionOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder dataParameters(Map<String, Object> dataParameters) {
|
||||
|
||||
this.dataParameters = dataParameters;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
|
||||
this.initDone = initDone;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param chromosomeFactory The ChromosomeFactory to use
|
||||
*/
|
||||
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
|
||||
this.chromosomeFactory = chromosomeFactory;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSearchCandidateGenerator build() {
|
||||
if (populationModel == null) {
|
||||
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
|
||||
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
|
||||
.build();
|
||||
}
|
||||
|
||||
if (chromosomeFactory == null) {
|
||||
chromosomeFactory = new ChromosomeFactory();
|
||||
}
|
||||
|
||||
if (selectionOperator == null) {
|
||||
selectionOperator = new GeneticSelectionOperator.Builder().build();
|
||||
}
|
||||
|
||||
return new GeneticSearchCandidateGenerator(this);
|
||||
}
|
||||
}
|
||||
|
||||
private GeneticSearchCandidateGenerator(Builder builder) {
|
||||
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
|
||||
|
||||
initialize();
|
||||
|
||||
chromosomeFactory = builder.chromosomeFactory;
|
||||
populationModel = builder.populationModel;
|
||||
selectionOperator = builder.selectionOperator;
|
||||
|
||||
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
|
||||
populationModel.initializeInstance(builder.minimizeScore);
|
||||
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return hasMoreCandidates;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
|
||||
double[] values = null;
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
|
||||
try {
|
||||
values = selectionOperator.buildNextGenes();
|
||||
value = parameterSpace.getValue(values);
|
||||
} catch (GeneticGenerationException e2) {
|
||||
log.warn("Error generating candidate", e2);
|
||||
e = e2;
|
||||
hasMoreCandidates = false;
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "GeneticSearchCandidateGenerator";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reportResults(OptimizationResult result) {
|
||||
if (result.getScore() == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
|
||||
result.getScore());
|
||||
populationModel.add(newChromosome);
|
||||
}
|
||||
}
|
|
@ -1,234 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.math3.random.RandomAdaptor;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
|
||||
|
||||
/**
|
||||
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
|
||||
* Note that:<br>
|
||||
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
|
||||
* that hyperparameter<br>
|
||||
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
|
||||
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
|
||||
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
|
||||
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
|
||||
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
|
||||
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
|
||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
||||
|
||||
/**
|
||||
* In what order should candidates be generated?<br>
|
||||
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
|
||||
* will be changed least rapidly.<br>
|
||||
* <b>RandomOrder</b>: generate candidates in a random order<br>
|
||||
* In both cases, the same candidates will be generated; only the order of generation is different
|
||||
*/
|
||||
public enum Mode {
|
||||
Sequential, RandomOrder
|
||||
}
|
||||
|
||||
private final int discretizationCount;
|
||||
private final Mode mode;
|
||||
|
||||
private int[] numValuesPerParam;
|
||||
@Getter
|
||||
private int totalNumCandidates;
|
||||
private Queue<Integer> order;
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||
* in which candidates should be generated.
|
||||
*/
|
||||
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
|
||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||
@JsonProperty("initDone") boolean initDone) {
|
||||
super(parameterSpace, dataParameters, initDone);
|
||||
this.discretizationCount = discretizationCount;
|
||||
this.mode = mode;
|
||||
initialize();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
||||
* in which candidates should be generated.
|
||||
*/
|
||||
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
|
||||
Map<String, Object> dataParameters){
|
||||
this(parameterSpace, discretizationCount, mode, dataParameters, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void initialize() {
|
||||
super.initialize();
|
||||
|
||||
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
||||
int nParams = leaves.size();
|
||||
|
||||
//Work out for each parameter: is it continuous or discrete?
|
||||
// for grid search: discrete values are grid-searchable as-is
|
||||
// continuous values: discretize using 'discretizationCount' bins
|
||||
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
|
||||
numValuesPerParam = new int[nParams];
|
||||
long searchSize = 1;
|
||||
for (int i = 0; i < nParams; i++) {
|
||||
ParameterSpace ps = leaves.get(i);
|
||||
if (ps instanceof DiscreteParameterSpace) {
|
||||
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
|
||||
numValuesPerParam[i] = dps.numValues();
|
||||
} else if (ps instanceof IntegerParameterSpace) {
|
||||
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
|
||||
int min = ips.getMin();
|
||||
int max = ips.getMax();
|
||||
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
||||
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
||||
} else if (ps instanceof FixedValue){
|
||||
numValuesPerParam[i] = 1;
|
||||
} else {
|
||||
numValuesPerParam[i] = discretizationCount;
|
||||
}
|
||||
searchSize *= numValuesPerParam[i];
|
||||
}
|
||||
|
||||
if (searchSize >= Integer.MAX_VALUE)
|
||||
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
|
||||
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
|
||||
|
||||
order = new ConcurrentLinkedQueue<>();
|
||||
|
||||
totalNumCandidates = (int) searchSize;
|
||||
switch (mode) {
|
||||
case Sequential:
|
||||
for (int i = 0; i < totalNumCandidates; i++) {
|
||||
order.add(i);
|
||||
}
|
||||
break;
|
||||
case RandomOrder:
|
||||
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
|
||||
for (int i = 0; i < totalNumCandidates; i++) {
|
||||
tempList.add(i);
|
||||
}
|
||||
|
||||
Collections.shuffle(tempList, new RandomAdaptor(rng));
|
||||
order.addAll(tempList);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return !order.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
int next = order.remove();
|
||||
|
||||
//Next: max integer (candidate number) to values
|
||||
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
|
||||
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
try {
|
||||
value = parameterSpace.getValue(values);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
|
||||
//How? first map to index of num possible values. Then: to double values in range 0 to 1
|
||||
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
||||
//Based on: Nd4j Shape.ind2sub
|
||||
|
||||
int countNon1 = 0;
|
||||
for( int i : numValuesPerParam)
|
||||
if(i > 1)
|
||||
countNon1++;
|
||||
|
||||
int denom = product;
|
||||
int num = candidateIdx;
|
||||
int[] index = new int[numValuesPerParam.length];
|
||||
|
||||
for (int i = index.length - 1; i >= 0; i--) {
|
||||
denom /= numValuesPerParam[i];
|
||||
index[i] = num / denom;
|
||||
num %= denom;
|
||||
}
|
||||
|
||||
//Now: convert indexes to values in range [0,1]
|
||||
//min value -> 0
|
||||
//max value -> 1
|
||||
double[] out = new double[countNon1];
|
||||
int outIdx = 0;
|
||||
for (int i = 0; i < numValuesPerParam.length; i++) {
|
||||
if (numValuesPerParam[i] > 1){
|
||||
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "GridSearchCandidateGenerator(mode=" + mode + ")";
|
||||
}
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* RandomSearchGenerator: generates candidates at random.<br>
|
||||
* Note: if a probability distribution is provided for continuous hyperparameters,
|
||||
* this will be taken into account
|
||||
* when generating candidates. This allows the search to be weighted more towards
|
||||
* certain values according to a probability
|
||||
* density. For example: generate samples for learning rate according to log uniform distribution
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
||||
public class RandomSearchGenerator extends BaseCandidateGenerator {
|
||||
|
||||
@JsonCreator
|
||||
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
||||
@JsonProperty("initDone") boolean initDone) {
|
||||
super(parameterSpace, dataParameters, initDone);
|
||||
initialize();
|
||||
}
|
||||
|
||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
|
||||
this(parameterSpace, dataParameters, false);
|
||||
}
|
||||
|
||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
|
||||
this(parameterSpace, null, false);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasMoreCandidates() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Candidate getCandidate() {
|
||||
double[] randomValues = new double[parameterSpace.numParameters()];
|
||||
for (int i = 0; i < randomValues.length; i++)
|
||||
randomValues[i] = rng.nextDouble();
|
||||
|
||||
Object value = null;
|
||||
Exception e = null;
|
||||
try {
|
||||
value = parameterSpace.getValue(randomValues);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Error getting configuration for candidate", e2);
|
||||
e = e2;
|
||||
}
|
||||
|
||||
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<?> getCandidateType() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "RandomSearchGenerator";
|
||||
}
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Candidates are stored as Chromosome in the population model
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Data
|
||||
public class Chromosome {
|
||||
/**
|
||||
* The fitness score of the genes.
|
||||
*/
|
||||
protected final double fitness;
|
||||
|
||||
/**
|
||||
* The genes.
|
||||
*/
|
||||
protected final double[] genes;
|
||||
|
||||
public Chromosome(double[] genes, double fitness) {
|
||||
this.genes = genes;
|
||||
this.fitness = fitness;
|
||||
}
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
||||
|
||||
/**
|
||||
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class ChromosomeFactory {
|
||||
private int chromosomeLength;
|
||||
|
||||
/**
|
||||
* Called by the GeneticSearchCandidateGenerator.
|
||||
*/
|
||||
public void initializeInstance(int chromosomeLength) {
|
||||
this.chromosomeLength = chromosomeLength;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new instance of a Chromosome
|
||||
*
|
||||
* @param genes The genes
|
||||
* @param fitness The fitness score
|
||||
* @return A new instance of Chromosome
|
||||
*/
|
||||
public Chromosome createChromosome(double[] genes, double fitness) {
|
||||
return new Chromosome(genes, fitness);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The number of genes in a chromosome
|
||||
*/
|
||||
public int getChromosomeLength() {
|
||||
return chromosomeLength;
|
||||
}
|
||||
}
|
|
@ -1,122 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* A crossover operator that linearly combines the genes of two parents. <br>
|
||||
* When a crossover is generated (with a of probability <i>crossover rate</i>), each genes is a linear combination of the corresponding genes of the parents.
|
||||
* <p>
|
||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ArithmeticCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new ArithmeticCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
|
||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
double[] offspringValues = new double[parents[0].length];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
for (int i = 0; i < offspringValues.length; ++i) {
|
||||
double t = rng.nextDouble();
|
||||
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
|
||||
}
|
||||
return new CrossoverResult(true, offspringValues);
|
||||
}
|
||||
|
||||
return new CrossoverResult(false, parents[0]);
|
||||
}
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* Abstract class for all crossover operators
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class CrossoverOperator {
|
||||
protected PopulationModel populationModel;
|
||||
|
||||
/**
|
||||
* Will be called by the selection operator once the population model is instantiated.
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
this.populationModel = populationModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the crossover
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
public abstract CrossoverResult crossover();
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Returned by a crossover operator
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@Data
|
||||
public class CrossoverResult {
|
||||
/**
|
||||
* If false, there was no crossover and the operator simply returned the genes of a random parent.
|
||||
* If true, the genes are the result of a crossover.
|
||||
*/
|
||||
private final boolean isModified;
|
||||
|
||||
/**
|
||||
* The genes returned by the operator.
|
||||
*/
|
||||
private final double[] genes;
|
||||
|
||||
public CrossoverResult(boolean isModified, double[] genes) {
|
||||
this.isModified = isModified;
|
||||
this.genes = genes;
|
||||
}
|
||||
}
|
|
@ -1,180 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
import java.util.Deque;
|
||||
|
||||
/**
|
||||
* The K-Point crossover will select at random multiple crossover points.<br>
|
||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
|
||||
*/
|
||||
public class KPointCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
private static final int DEFAULT_MIN_CROSSOVER = 1;
|
||||
private static final int DEFAULT_MAX_CROSSOVER = 4;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final int minCrossovers;
|
||||
private final int maxCrossovers;
|
||||
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
|
||||
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of crossovers points (default is min 1, max 4)
|
||||
*
|
||||
* @param min The minimum number
|
||||
* @param max The maximum number
|
||||
*/
|
||||
public Builder numCrossovers(int min, int max) {
|
||||
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
|
||||
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
|
||||
|
||||
this.minCrossovers = min;
|
||||
this.maxCrossovers = max;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a fixed number of crossover points
|
||||
*
|
||||
* @param num The number of crossovers
|
||||
*/
|
||||
public Builder numCrossovers(int num) {
|
||||
Preconditions.checkState(num >= 0, "Num must be positive");
|
||||
|
||||
this.minCrossovers = num;
|
||||
this.maxCrossovers = num;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public KPointCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new KPointCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private CrossoverPointsGenerator crossoverPointsGenerator;
|
||||
|
||||
private KPointCrossover(KPointCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.maxCrossovers = builder.maxCrossovers;
|
||||
this.minCrossovers = builder.minCrossovers;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
|
||||
if (crossoverPointsGenerator == null) {
|
||||
crossoverPointsGenerator =
|
||||
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
|
||||
}
|
||||
|
||||
return crossoverPointsGenerator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
|
||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
boolean isModified = false;
|
||||
double[] resultGenes = parents[0];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
// Select crossover points
|
||||
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
|
||||
|
||||
// Crossover
|
||||
resultGenes = new double[parents[0].length];
|
||||
int currentParent = 0;
|
||||
int nextCrossover = crossoverPoints.pop();
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
if (i == nextCrossover) {
|
||||
currentParent = currentParent == 0 ? 1 : 0;
|
||||
nextCrossover = crossoverPoints.pop();
|
||||
}
|
||||
resultGenes[i] = parents[currentParent][i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -1,125 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* The single point crossover will select a random point where every genes before that point comes from one parent
|
||||
* and after which every genes comes from the other parent.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
|
||||
private final RandomGenerator rng;
|
||||
private final double crossoverRate;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public SinglePointCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
|
||||
return new SinglePointCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
|
||||
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
public CrossoverResult crossover() {
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
boolean isModified = false;
|
||||
double[] resultGenes = parents[0];
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
int chromosomeLength = parents[0].length;
|
||||
|
||||
// Crossover
|
||||
resultGenes = new double[chromosomeLength];
|
||||
|
||||
int crossoverPoint = rng.nextInt(chromosomeLength);
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* Abstract class for all crossover operators that applies to two parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
|
||||
|
||||
protected final TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* @param parentSelection A parent selection that selects two parents.
|
||||
*/
|
||||
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Will be called by the selection operator once the population model is instantiated.
|
||||
*/
|
||||
@Override
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
super.initializeInstance(populationModel);
|
||||
parentSelection.initializeInstance(populationModel.getPopulation());
|
||||
}
|
||||
}
|
|
@ -1,138 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class UniformCrossover extends TwoParentsCrossoverOperator {
|
||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
||||
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
|
||||
|
||||
private final double crossoverRate;
|
||||
private final double parentBiasFactor;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
||||
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
|
||||
private RandomGenerator rng;
|
||||
private TwoParentSelection parentSelection;
|
||||
|
||||
/**
|
||||
* The probability that the operator generates a crossover (default 0.85).
|
||||
*
|
||||
* @param rate A value between 0.0 and 1.0
|
||||
*/
|
||||
public Builder crossoverRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.crossoverRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A factor that will introduce a bias in the parent selection.<br>
|
||||
*
|
||||
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
|
||||
*/
|
||||
public Builder parentBiasFactor(double factor) {
|
||||
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
|
||||
factor);
|
||||
|
||||
this.parentBiasFactor = factor;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The parent selection behavior. Default is random parent selection.
|
||||
*
|
||||
* @param parentSelection An instance of TwoParentSelection
|
||||
*/
|
||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
||||
this.parentSelection = parentSelection;
|
||||
return this;
|
||||
}
|
||||
|
||||
public UniformCrossover build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
if (parentSelection == null) {
|
||||
parentSelection = new RandomTwoParentSelection();
|
||||
}
|
||||
return new UniformCrossover(this);
|
||||
}
|
||||
}
|
||||
|
||||
private UniformCrossover(UniformCrossover.Builder builder) {
|
||||
super(builder.parentSelection);
|
||||
|
||||
this.crossoverRate = builder.crossoverRate;
|
||||
this.parentBiasFactor = builder.parentBiasFactor;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
|
||||
* One of the parent may be favored if the bias is different than 0.5
|
||||
* Otherwise, returns the genes of a random parent.
|
||||
*
|
||||
* @return The crossover result. See {@link CrossoverResult}.
|
||||
*/
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
// select the parents
|
||||
double[][] parents = parentSelection.selectParents();
|
||||
|
||||
double[] resultGenes = parents[0];
|
||||
boolean isModified = false;
|
||||
|
||||
if (rng.nextDouble() < crossoverRate) {
|
||||
// Crossover
|
||||
resultGenes = new double[parents[0].length];
|
||||
|
||||
for (int i = 0; i < resultGenes.length; ++i) {
|
||||
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
|
||||
}
|
||||
isModified = true;
|
||||
}
|
||||
|
||||
return new CrossoverResult(isModified, resultGenes);
|
||||
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Abstract class for all parent selection behaviors
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class ParentSelection {
|
||||
protected List<Chromosome> population;
|
||||
|
||||
/**
|
||||
* Will be called by the crossover operator once the population model is instantiated.
|
||||
*/
|
||||
public void initializeInstance(List<Chromosome> population) {
|
||||
this.population = population;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the parent selection
|
||||
*
|
||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||
*/
|
||||
public abstract double[][] selectParents();
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
|
||||
/**
|
||||
* A parent selection behavior that returns two random parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class RandomTwoParentSelection extends TwoParentSelection {
|
||||
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public RandomTwoParentSelection() {
|
||||
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public RandomTwoParentSelection(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects two random parents
|
||||
*
|
||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
||||
*/
|
||||
@Override
|
||||
public double[][] selectParents() {
|
||||
double[][] parents = new double[2][];
|
||||
|
||||
int parent1Idx = rng.nextInt(population.size());
|
||||
int parent2Idx;
|
||||
do {
|
||||
parent2Idx = rng.nextInt(population.size());
|
||||
} while (parent1Idx == parent2Idx);
|
||||
|
||||
parents[0] = population.get(parent1Idx).getGenes();
|
||||
parents[1] = population.get(parent2Idx).getGenes();
|
||||
|
||||
return parents;
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
||||
|
||||
/**
|
||||
* Abstract class for all parent selection behaviors that selects two parents.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class TwoParentSelection extends ParentSelection {
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A helper class used by {@link KPointCrossover} to generate the crossover points
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class CrossoverPointsGenerator {
|
||||
private final int minCrossovers;
|
||||
private final int maxCrossovers;
|
||||
private final RandomGenerator rng;
|
||||
private List<Integer> parameterIndexes;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param chromosomeLength The number of genes
|
||||
* @param minCrossovers The minimum number of crossover points to generate
|
||||
* @param maxCrossovers The maximum number of crossover points to generate
|
||||
* @param rng A RandomGenerator instance
|
||||
*/
|
||||
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
|
||||
this.minCrossovers = minCrossovers;
|
||||
this.maxCrossovers = maxCrossovers;
|
||||
this.rng = rng;
|
||||
parameterIndexes = new ArrayList<Integer>();
|
||||
for (int i = 0; i < chromosomeLength; ++i) {
|
||||
parameterIndexes.add(i);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a list of crossover points.
|
||||
*
|
||||
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
|
||||
*/
|
||||
public Deque<Integer> getCrossoverPoints() {
|
||||
Collections.shuffle(parameterIndexes);
|
||||
List<Integer> crossoverPointLists =
|
||||
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
|
||||
Collections.sort(crossoverPointLists);
|
||||
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
|
||||
crossoverPoints.add(Integer.MAX_VALUE);
|
||||
|
||||
return crossoverPoints;
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* The cull operator will remove from the population the least desirables chromosomes.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface CullOperator {
|
||||
/**
|
||||
* Will be called by the population model once created.
|
||||
*/
|
||||
void initializeInstance(PopulationModel populationModel);
|
||||
|
||||
/**
|
||||
* Cull the population to the culled size.
|
||||
*/
|
||||
void cullPopulation();
|
||||
|
||||
/**
|
||||
* @return The target population size after culling.
|
||||
*/
|
||||
int getCulledSize();
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
/**
|
||||
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class LeastFitCullOperator extends RatioCullOperator {
|
||||
|
||||
/**
|
||||
* The default cull ratio is 1/3.
|
||||
*/
|
||||
public LeastFitCullOperator() {
|
||||
super();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||
*/
|
||||
public LeastFitCullOperator(double cullRatio) {
|
||||
super(cullRatio);
|
||||
}
|
||||
|
||||
/**
|
||||
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
|
||||
*/
|
||||
@Override
|
||||
public void cullPopulation() {
|
||||
while (population.size() > culledSize) {
|
||||
population.remove(population.size() - 1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,72 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class RatioCullOperator implements CullOperator {
|
||||
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
|
||||
protected int culledSize;
|
||||
protected List<Chromosome> population;
|
||||
protected final double cullRatio;
|
||||
|
||||
/**
|
||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
||||
*/
|
||||
public RatioCullOperator(double cullRatio) {
|
||||
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
|
||||
cullRatio);
|
||||
|
||||
this.cullRatio = cullRatio;
|
||||
}
|
||||
|
||||
/**
|
||||
* The default cull ratio is 1/3
|
||||
*/
|
||||
public RatioCullOperator() {
|
||||
this(DEFAULT_CULL_RATIO);
|
||||
}
|
||||
|
||||
/**
|
||||
* Will be called by the population model once created.
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel) {
|
||||
this.population = populationModel.getPopulation();
|
||||
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The target population size after culling.
|
||||
*/
|
||||
@Override
|
||||
public int getCulledSize() {
|
||||
return culledSize;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
|
||||
|
||||
public class GeneticGenerationException extends RuntimeException {
|
||||
public GeneticGenerationException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||
|
||||
/**
|
||||
* The mutation operator will apply a mutation to the given genes.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface MutationOperator {
|
||||
|
||||
/**
|
||||
* Performs a mutation.
|
||||
*
|
||||
* @param genes The genes to be mutated
|
||||
* @return True if the genes were mutated, otherwise false.
|
||||
*/
|
||||
boolean mutate(double[] genes);
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class RandomMutationOperator implements MutationOperator {
|
||||
private static final double DEFAULT_MUTATION_RATE = 0.005;
|
||||
|
||||
private final double mutationRate;
|
||||
private final RandomGenerator rng;
|
||||
|
||||
public static class Builder {
|
||||
private double mutationRate = DEFAULT_MUTATION_RATE;
|
||||
private RandomGenerator rng;
|
||||
|
||||
/**
|
||||
* Each gene will have this probability of being mutated.
|
||||
*
|
||||
* @param rate The mutation rate. (default 0.005)
|
||||
*/
|
||||
public Builder mutationRate(double rate) {
|
||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
||||
|
||||
this.mutationRate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RandomMutationOperator build() {
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
return new RandomMutationOperator(this);
|
||||
}
|
||||
}
|
||||
|
||||
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
|
||||
this.mutationRate = builder.mutationRate;
|
||||
this.rng = builder.rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
|
||||
*
|
||||
* @param genes The genes to be mutated
|
||||
* @return True if the genes were mutated, otherwise false.
|
||||
*/
|
||||
@Override
|
||||
public boolean mutate(double[] genes) {
|
||||
boolean hasMutation = false;
|
||||
|
||||
for (int i = 0; i < genes.length; ++i) {
|
||||
if (rng.nextDouble() < mutationRate) {
|
||||
genes[i] = rng.nextDouble();
|
||||
hasMutation = true;
|
||||
}
|
||||
}
|
||||
|
||||
return hasMutation;
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A population initializer that build an empty population.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class EmptyPopulationInitializer implements PopulationInitializer {
|
||||
|
||||
/**
|
||||
* Initialize an empty population
|
||||
*
|
||||
* @param size The maximum size of the population.
|
||||
* @return The initialized population.
|
||||
*/
|
||||
@Override
|
||||
public List<Chromosome> getInitializedPopulation(int size) {
|
||||
return new ArrayList<>(size);
|
||||
}
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An initializer that construct the population used by the population model.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface PopulationInitializer {
|
||||
/**
|
||||
* Called by the population model to construct the population
|
||||
*
|
||||
* @param size The maximum size of the population
|
||||
* @return An initialized population
|
||||
*/
|
||||
List<Chromosome> getInitializedPopulation(int size);
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A listener that is called when the population changes.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface PopulationListener {
|
||||
/**
|
||||
* Called after the population has changed.
|
||||
*
|
||||
* @param population The population after it has changed.
|
||||
*/
|
||||
void onChanged(List<Chromosome> population);
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The population model handles all aspects of the population (initialization, additions and culling)
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class PopulationModel {
|
||||
private static final int DEFAULT_POPULATION_SIZE = 30;
|
||||
|
||||
private final CullOperator cullOperator;
|
||||
private final List<PopulationListener> populationListeners = new ArrayList<>();
|
||||
private Comparator<Chromosome> chromosomeComparator;
|
||||
|
||||
/**
|
||||
* The maximum population size
|
||||
*/
|
||||
@Getter
|
||||
private final int populationSize;
|
||||
|
||||
/**
|
||||
* The population
|
||||
*/
|
||||
@Getter
|
||||
public final List<Chromosome> population;
|
||||
|
||||
/**
|
||||
* A comparator used when higher fitness value is better
|
||||
*/
|
||||
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
|
||||
@Override
|
||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||
return -Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A comparator used when lower fitness value is better
|
||||
*/
|
||||
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
|
||||
@Override
|
||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
||||
return Double.compare(lhs.getFitness(), rhs.getFitness());
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private int populationSize = DEFAULT_POPULATION_SIZE;
|
||||
private PopulationInitializer populationInitializer;
|
||||
private CullOperator cullOperator;
|
||||
|
||||
/**
|
||||
* Use an alternate population initialization behavior. Default is empty population.
|
||||
*
|
||||
* @param populationInitializer An instance of PopulationInitializer
|
||||
*/
|
||||
public Builder populationInitializer(PopulationInitializer populationInitializer) {
|
||||
this.populationInitializer = populationInitializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The maximum population size. <br>
|
||||
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
|
||||
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
|
||||
*
|
||||
* @param size The maximum size of the population
|
||||
*/
|
||||
public Builder populationSize(int size) {
|
||||
populationSize = size;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use an alternate cull operator behavior. Default is least fit culling.
|
||||
*
|
||||
* @param cullOperator An instance of a CullOperator
|
||||
*/
|
||||
public Builder cullOperator(CullOperator cullOperator) {
|
||||
this.cullOperator = cullOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PopulationModel build() {
|
||||
if (cullOperator == null) {
|
||||
cullOperator = new LeastFitCullOperator();
|
||||
}
|
||||
|
||||
if (populationInitializer == null) {
|
||||
populationInitializer = new EmptyPopulationInitializer();
|
||||
}
|
||||
|
||||
return new PopulationModel(this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public PopulationModel(PopulationModel.Builder builder) {
|
||||
populationSize = builder.populationSize;
|
||||
population = new ArrayList<>(builder.populationSize);
|
||||
PopulationInitializer populationInitializer = builder.populationInitializer;
|
||||
|
||||
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
|
||||
population.clear();
|
||||
population.addAll(initializedPopulation);
|
||||
|
||||
cullOperator = builder.cullOperator;
|
||||
cullOperator.initializeInstance(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by the GeneticSearchCandidateGenerator
|
||||
*/
|
||||
public void initializeInstance(boolean minimizeScore) {
|
||||
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a PopulationListener to the list of change listeners
|
||||
* @param listener A PopulationListener instance
|
||||
*/
|
||||
public void addListener(PopulationListener listener) {
|
||||
populationListeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
|
||||
*
|
||||
* @param element The chromosome to be added
|
||||
*/
|
||||
public void add(Chromosome element) {
|
||||
if (population.size() == populationSize) {
|
||||
cullOperator.cullPopulation();
|
||||
}
|
||||
|
||||
population.add(element);
|
||||
|
||||
Collections.sort(population, chromosomeComparator);
|
||||
|
||||
triggerPopulationChangedListeners(population);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Return false when the population is below the culled size, otherwise true. <br>
|
||||
* Used by the selection operator to know if the population is still too small and should generate random genes.
|
||||
*/
|
||||
public boolean isReadyToBreed() {
|
||||
return population.size() >= cullOperator.getCulledSize();
|
||||
}
|
||||
|
||||
private void triggerPopulationChangedListeners(List<Chromosome> population) {
|
||||
for (PopulationListener listener : populationListeners) {
|
||||
listener.onChanged(population);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,199 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||
|
||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
|
||||
* will start to generate offsprings of parents selected in the population.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class GeneticSelectionOperator extends SelectionOperator {
|
||||
|
||||
private final static int PREVIOUS_GENES_TO_KEEP = 100;
|
||||
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
|
||||
|
||||
private final CrossoverOperator crossoverOperator;
|
||||
private final MutationOperator mutationOperator;
|
||||
private final RandomGenerator rng;
|
||||
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
|
||||
private int previousGenesIdx = 0;
|
||||
|
||||
public static class Builder {
|
||||
private ChromosomeFactory chromosomeFactory;
|
||||
private PopulationModel populationModel;
|
||||
private CrossoverOperator crossoverOperator;
|
||||
private MutationOperator mutationOperator;
|
||||
private RandomGenerator rng;
|
||||
|
||||
/**
|
||||
* Use an alternate crossover behavior. Default is SinglePointCrossover.
|
||||
*
|
||||
* @param crossoverOperator An instance of CrossoverOperator
|
||||
*/
|
||||
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
|
||||
this.crossoverOperator = crossoverOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use an alternate mutation behavior. Default is RandomMutationOperator.
|
||||
*
|
||||
* @param mutationOperator An instance of MutationOperator
|
||||
*/
|
||||
public Builder mutationOperator(MutationOperator mutationOperator) {
|
||||
this.mutationOperator = mutationOperator;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a supplied RandomGenerator
|
||||
*
|
||||
* @param rng An instance of RandomGenerator
|
||||
*/
|
||||
public Builder randomGenerator(RandomGenerator rng) {
|
||||
this.rng = rng;
|
||||
return this;
|
||||
}
|
||||
|
||||
public GeneticSelectionOperator build() {
|
||||
if (crossoverOperator == null) {
|
||||
crossoverOperator = new SinglePointCrossover.Builder().build();
|
||||
}
|
||||
|
||||
if (mutationOperator == null) {
|
||||
mutationOperator = new RandomMutationOperator.Builder().build();
|
||||
}
|
||||
|
||||
if (rng == null) {
|
||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
||||
}
|
||||
|
||||
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
|
||||
}
|
||||
}
|
||||
|
||||
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
|
||||
RandomGenerator rng) {
|
||||
this.crossoverOperator = crossoverOperator;
|
||||
this.mutationOperator = mutationOperator;
|
||||
this.rng = rng;
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by GeneticSearchCandidateGenerator
|
||||
*/
|
||||
@Override
|
||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||
super.initializeInstance(populationModel, chromosomeFactory);
|
||||
crossoverOperator.initializeInstance(populationModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new set of genes. Has two distinct modes of operation
|
||||
* <ul>
|
||||
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
|
||||
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
|
||||
* </ul>
|
||||
* @return Returns the generated set of genes
|
||||
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
|
||||
* or if the crossover and the mutation operators can't generate a set,
|
||||
* this exception is thrown.
|
||||
*/
|
||||
@Override
|
||||
public double[] buildNextGenes() {
|
||||
double[] result;
|
||||
|
||||
boolean hasAlreadyBeenTried;
|
||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||
do {
|
||||
if (populationModel.isReadyToBreed()) {
|
||||
result = buildOffspring();
|
||||
} else {
|
||||
result = buildRandomGenes();
|
||||
}
|
||||
|
||||
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
|
||||
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
|
||||
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
|
||||
}
|
||||
} while (hasAlreadyBeenTried);
|
||||
|
||||
previousGenes[previousGenesIdx] = result;
|
||||
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private boolean hasAlreadyBeenTried(double[] genes) {
|
||||
for (int i = 0; i < previousGenes.length; ++i) {
|
||||
double[] current = previousGenes[i];
|
||||
if (current != null && Arrays.equals(current, genes)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private double[] buildOffspring() {
|
||||
double[] offspringValues;
|
||||
|
||||
boolean isModified;
|
||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
||||
do {
|
||||
CrossoverResult crossoverResult = crossoverOperator.crossover();
|
||||
offspringValues = crossoverResult.getGenes();
|
||||
isModified = crossoverResult.isModified();
|
||||
isModified |= mutationOperator.mutate(offspringValues);
|
||||
|
||||
if (!isModified && --attemptsRemaining == 0) {
|
||||
throw new GeneticGenerationException(
|
||||
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
|
||||
MAX_NUM_GENERATION_ATTEMPTS));
|
||||
}
|
||||
} while (!isModified);
|
||||
|
||||
return offspringValues;
|
||||
}
|
||||
|
||||
private double[] buildRandomGenes() {
|
||||
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
|
||||
for (int i = 0; i < randomValues.length; ++i) {
|
||||
randomValues[i] = rng.nextDouble();
|
||||
}
|
||||
|
||||
return randomValues;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
/**
|
||||
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class SelectionOperator {
|
||||
protected PopulationModel populationModel;
|
||||
protected ChromosomeFactory chromosomeFactory;
|
||||
|
||||
/**
|
||||
* Called by GeneticSearchCandidateGenerator
|
||||
*/
|
||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
||||
|
||||
this.populationModel = populationModel;
|
||||
this.chromosomeFactory = chromosomeFactory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new set of genes.
|
||||
*/
|
||||
public abstract double[] buildNextGenes();
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.generator.util;
|
||||
|
||||
import org.nd4j.common.function.Supplier;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
public class SerializedSupplier<T> implements Serializable, Supplier<T> {
|
||||
|
||||
private byte[] asBytes;
|
||||
|
||||
public SerializedSupplier(T obj){
|
||||
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
||||
oos.writeObject(obj);
|
||||
oos.flush();
|
||||
oos.close();
|
||||
asBytes = baos.toByteArray();
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException("Error serializing object - must be serializable",e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public T get() {
|
||||
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))){
|
||||
return (T)ois.readObject();
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException("Error deserializing object",e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* BooleanParameterSpace is a {@code ParameterSpace<Boolean>}; Defines {True, False} as a parameter space
|
||||
* If argument to setValue is less than or equal to 0.5 it will return True else False
|
||||
*
|
||||
* @author susaneraly
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
public class BooleanSpace implements ParameterSpace<Boolean> {
|
||||
private int index = -1;
|
||||
|
||||
@Override
|
||||
public Boolean getValue(double[] input) {
|
||||
if (index == -1) {
|
||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
||||
}
|
||||
if (input[index] <= 0.5) return Boolean.TRUE;
|
||||
else return Boolean.FALSE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.singletonList((ParameterSpace) this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
if (indices == null || indices.length != 1)
|
||||
throw new IllegalArgumentException("Invalid index");
|
||||
this.index = indices[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "BooleanSpace()";
|
||||
}
|
||||
}
|
|
@ -1,92 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
|
||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* FixedValue is a ParameterSpace that defines only a single fixed value
|
||||
*
|
||||
* @param <T> Type of (fixed) value
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@JsonSerialize(using = FixedValueSerializer.class)
|
||||
@JsonDeserialize(using = FixedValueDeserializer.class)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public class FixedValue<T> implements ParameterSpace<T> {
|
||||
@Getter
|
||||
private Object value;
|
||||
private int index;
|
||||
|
||||
@JsonCreator
|
||||
public FixedValue(@JsonProperty("value") T value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "FixedValue(" + ObjectUtils.valueToString(value) + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public T getValue(double[] input) {
|
||||
return (T) value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
if (indices != null && indices.length != 0)
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid call: FixedValue ParameterSpace " + "should not be given an index (0 params)");
|
||||
}
|
||||
}
|
|
@ -1,137 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.continuous;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* ContinuousParametSpace is a {@code ParameterSpace<Double>} that (optionally) takes an Apache Commons
|
||||
* {@link RealDistribution} when used for random sampling (such as in a RandomSearchCandidateGenerator)
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class ContinuousParameterSpace implements ParameterSpace<Double> {
|
||||
|
||||
//Need to use custom serializers/deserializers for commons RealDistribution instances
|
||||
@JsonSerialize(using = RealDistributionSerializer.class)
|
||||
@JsonDeserialize(using = RealDistributionDeserializer.class)
|
||||
private RealDistribution distribution;
|
||||
private int index = -1;
|
||||
|
||||
/**
|
||||
* ContinuousParameterSpace with uniform distribution between the minimum and maximum values
|
||||
*
|
||||
* @param min Minimum value that can be generated
|
||||
* @param max Maximum value that can be generated
|
||||
*/
|
||||
public ContinuousParameterSpace(double min, double max) {
|
||||
this(new UniformRealDistribution(min, max));
|
||||
}
|
||||
|
||||
/**
|
||||
* ConditiousParameterSpcae wiht a specified probability distribution. The provided distribution defines the min/max
|
||||
* values, and (for random search, etc) will be used when generating random values
|
||||
*
|
||||
* @param distribution Distribution to sample from
|
||||
*/
|
||||
public ContinuousParameterSpace(@JsonProperty("distribution") RealDistribution distribution) {
|
||||
this.distribution = distribution;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Double getValue(double[] input) {
|
||||
if (index == -1) {
|
||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
||||
}
|
||||
return distribution.inverseCumulativeProbability(input[index]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.singletonList((ParameterSpace) this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
if (indices == null || indices.length != 1) {
|
||||
throw new IllegalArgumentException("Invalid index");
|
||||
}
|
||||
this.index = indices[0];
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (distribution instanceof UniformRealDistribution) {
|
||||
return "ContinuousParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
|
||||
+ distribution.getSupportUpperBound() + ")";
|
||||
} else {
|
||||
return "ContinuousParameterSpace(" + distribution + ")";
|
||||
}
|
||||
}
|
||||
|
||||
public boolean equals(Object o) {
|
||||
if (o == this)
|
||||
return true;
|
||||
if (!(o instanceof ContinuousParameterSpace))
|
||||
return false;
|
||||
final ContinuousParameterSpace other = (ContinuousParameterSpace) o;
|
||||
if (distribution == null ? other.distribution != null
|
||||
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
|
||||
return false;
|
||||
|
||||
return this.index == other.index;
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
final int PRIME = 59;
|
||||
int result = 1;
|
||||
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
|
||||
result = result * PRIME + this.index;
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -1,114 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.discrete;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A DiscreteParameterSpace is used for a set of un-ordered values
|
||||
*
|
||||
* @param <P> Parameter type
|
||||
* @author Alex Black
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
public class DiscreteParameterSpace<P> implements ParameterSpace<P> {
|
||||
|
||||
@JsonSerialize
|
||||
private List<P> values;
|
||||
private int index = -1;
|
||||
|
||||
public DiscreteParameterSpace(@JsonProperty("values") P... values) {
|
||||
if (values != null)
|
||||
this.values = Arrays.asList(values);
|
||||
}
|
||||
|
||||
public DiscreteParameterSpace(Collection<P> values) {
|
||||
this.values = new ArrayList<>(values);
|
||||
}
|
||||
|
||||
public int numValues() {
|
||||
return values.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public P getValue(double[] input) {
|
||||
if (index == -1) {
|
||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
||||
}
|
||||
if (values == null)
|
||||
throw new IllegalStateException("Values are null.");
|
||||
//Map a value in range [0,1] to one of the list of values
|
||||
//First value: [0,width], second: (width,2*width], third: (3*width,4*width] etc
|
||||
int size = values.size();
|
||||
if (size == 1)
|
||||
return values.get(0);
|
||||
double width = 1.0 / size;
|
||||
int val = (int) (input[index] / width);
|
||||
return values.get(Math.min(val, size - 1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.singletonList((ParameterSpace) this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
if (indices == null || indices.length != 1) {
|
||||
throw new IllegalArgumentException("Invalid index");
|
||||
}
|
||||
this.index = indices[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("DiscreteParameterSpace(");
|
||||
int n = values.size();
|
||||
for (int i = 0; i < n; i++) {
|
||||
P value = values.get(i);
|
||||
sb.append(ObjectUtils.valueToString(value));
|
||||
sb.append((i == n - 1 ? ")" : ","));
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,151 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.integer;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionDeserializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionSerializer;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* IntegerParameterSpace is a {@code ParameterSpace<Integer>}; i.e., defines an ordered space of integers between
|
||||
* some minimum and maximum value
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@JsonIgnoreProperties({"min", "max"})
|
||||
@NoArgsConstructor
|
||||
public class IntegerParameterSpace implements ParameterSpace<Integer> {
|
||||
|
||||
@JsonSerialize(using = IntegerDistributionSerializer.class)
|
||||
@JsonDeserialize(using = IntegerDistributionDeserializer.class)
|
||||
private IntegerDistribution distribution;
|
||||
private int index = -1;
|
||||
|
||||
/**
|
||||
* Create an IntegerParameterSpace with a uniform distribution between the specified min/max (inclusive)
|
||||
*
|
||||
* @param min Min value, inclusive
|
||||
* @param max Max value, inclusive
|
||||
*/
|
||||
public IntegerParameterSpace(int min, int max) {
|
||||
this(new UniformIntegerDistribution(min, max));
|
||||
}
|
||||
|
||||
/**
|
||||
* Crate an IntegerParametSpace from the given IntegerDistribution
|
||||
*
|
||||
* @param distribution Distribution to use
|
||||
*/
|
||||
@JsonCreator
|
||||
public IntegerParameterSpace(@JsonProperty("distribution") IntegerDistribution distribution) {
|
||||
this.distribution = distribution;
|
||||
}
|
||||
|
||||
public int getMin() {
|
||||
return distribution.getSupportLowerBound();
|
||||
}
|
||||
|
||||
public int getMax() {
|
||||
return distribution.getSupportUpperBound();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getValue(double[] input) {
|
||||
if (index == -1) {
|
||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
||||
}
|
||||
return distribution.inverseCumulativeProbability(input[index]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return Collections.singletonList((ParameterSpace) this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
if (indices == null || indices.length != 1)
|
||||
throw new IllegalArgumentException("Invalid index");
|
||||
this.index = indices[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (distribution instanceof UniformIntegerDistribution) {
|
||||
return "IntegerParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
|
||||
+ distribution.getSupportUpperBound() + ")";
|
||||
} else {
|
||||
return "IntegerParameterSpace(" + distribution + ")";
|
||||
}
|
||||
}
|
||||
|
||||
public boolean equals(Object o) {
|
||||
if (o == this)
|
||||
return true;
|
||||
if (!(o instanceof IntegerParameterSpace))
|
||||
return false;
|
||||
final IntegerParameterSpace other = (IntegerParameterSpace) o;
|
||||
if (!other.canEqual(this))
|
||||
return false;
|
||||
if (distribution == null ? other.distribution != null
|
||||
: !DistributionUtils.distributionEquals(distribution, other.distribution))
|
||||
return false;
|
||||
return this.index == other.index;
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
final int PRIME = 59;
|
||||
int result = 1;
|
||||
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
|
||||
result = result * PRIME + this.index;
|
||||
return result;
|
||||
}
|
||||
|
||||
protected boolean canEqual(Object other) {
|
||||
return other instanceof IntegerParameterSpace;
|
||||
}
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple parameter space that implements scalar mathematical operations on another parameter space. This allows you
|
||||
* to do things like Y = X * 2, where X is a parameter space. For example, a layer size hyperparameter could be set
|
||||
* using this to 2x the size of the previous layer
|
||||
*
|
||||
* @param <T> Type of the parameter space
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class MathOp<T extends Number> extends AbstractParameterSpace<T> {
|
||||
|
||||
private ParameterSpace<T> parameterSpace;
|
||||
private Op op;
|
||||
private T scalar;
|
||||
|
||||
public MathOp(ParameterSpace<T> parameterSpace, Op op, T scalar){
|
||||
this.parameterSpace = parameterSpace;
|
||||
this.op = op;
|
||||
this.scalar = scalar;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T getValue(double[] parameterValues) {
|
||||
T u = parameterSpace.getValue(parameterValues);
|
||||
return op.doOp(u, scalar);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return parameterSpace.numParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
return parameterSpace.collectLeaves();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
parameterSpace.setIndices(indices);
|
||||
}
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
||||
|
||||
public enum Op {
|
||||
ADD, SUB, MUL, DIV;
|
||||
|
||||
|
||||
//Package private
|
||||
<T extends Number> T doOp(T first, T second){
|
||||
if(first instanceof Integer || first instanceof Long){
|
||||
long result;
|
||||
switch (this){
|
||||
case ADD:
|
||||
result = Long.valueOf(first.longValue() + second.longValue());
|
||||
break;
|
||||
case SUB:
|
||||
result = Long.valueOf(first.longValue() - second.longValue());
|
||||
break;
|
||||
case MUL:
|
||||
result = Long.valueOf(first.longValue() * second.longValue());
|
||||
break;
|
||||
case DIV:
|
||||
result = Long.valueOf(first.longValue() / second.longValue());
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown op: " + this);
|
||||
}
|
||||
if(first instanceof Long){
|
||||
return (T)Long.valueOf(result);
|
||||
} else {
|
||||
return (T)Integer.valueOf((int)result);
|
||||
}
|
||||
} else if(first instanceof Double || first instanceof Float){
|
||||
double result;
|
||||
switch (this){
|
||||
case ADD:
|
||||
result = Double.valueOf(first.doubleValue() + second.doubleValue());
|
||||
break;
|
||||
case SUB:
|
||||
result = Double.valueOf(first.doubleValue() - second.doubleValue());
|
||||
break;
|
||||
case MUL:
|
||||
result = Double.valueOf(first.doubleValue() * second.doubleValue());
|
||||
break;
|
||||
case DIV:
|
||||
result = Double.valueOf(first.doubleValue() / second.doubleValue());
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown op: " + this);
|
||||
}
|
||||
if(first instanceof Double){
|
||||
return (T)Double.valueOf(result);
|
||||
} else {
|
||||
return (T)Float.valueOf((float)result);
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Not supported type: only Integer, Long, Double, Float supported" +
|
||||
" here. Got type: " + first.getClass());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple parameter space that implements pairwise mathematical operations on another parameter space. This allows you
|
||||
* to do things like Z = X + Y, where X and Y are parameter spaces.
|
||||
*
|
||||
* @param <T> Type of the parameter space
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class PairMathOp<T extends Number> extends AbstractParameterSpace<T> {
|
||||
|
||||
private ParameterSpace<T> first;
|
||||
private ParameterSpace<T> second;
|
||||
private Op op;
|
||||
|
||||
public PairMathOp(ParameterSpace<T> first, ParameterSpace<T> second, Op op){
|
||||
this.first = first;
|
||||
this.second = second;
|
||||
this.op = op;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T getValue(double[] parameterValues) {
|
||||
T f = first.getValue(parameterValues);
|
||||
T s = second.getValue(parameterValues);
|
||||
return op.doOp(f, s);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return first.numParameters() + second.numParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
List<ParameterSpace> l = new ArrayList<>();
|
||||
l.addAll(first.collectLeaves());
|
||||
l.addAll(second.collectLeaves());
|
||||
return l;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
int n1 = first.numParameters();
|
||||
int n2 = second.numParameters();
|
||||
int[] s1 = Arrays.copyOfRange(indices, 0, n1);
|
||||
int[] s2 = Arrays.copyOfRange(indices, n1, n1+n2);
|
||||
first.setIndices(s1);
|
||||
second.setIndices(s2);
|
||||
}
|
||||
}
|
|
@ -1,381 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner;
|
||||
|
||||
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.exception.ExceptionUtils;
|
||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* BaseOptimization runner: responsible for scheduling tasks, saving results using the result saver, etc.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseOptimizationRunner implements IOptimizationRunner {
|
||||
private static final int POLLING_FREQUENCY = 1;
|
||||
private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS;
|
||||
|
||||
protected OptimizationConfiguration config;
|
||||
protected Queue<Future<OptimizationResult>> queuedFutures = new ConcurrentLinkedQueue<>();
|
||||
protected BlockingQueue<Future<OptimizationResult>> completedFutures = new LinkedBlockingQueue<>();
|
||||
protected AtomicInteger totalCandidateCount = new AtomicInteger();
|
||||
protected AtomicInteger numCandidatesCompleted = new AtomicInteger();
|
||||
protected AtomicInteger numCandidatesFailed = new AtomicInteger();
|
||||
protected Double bestScore = null;
|
||||
protected Long bestScoreTime = null;
|
||||
protected AtomicInteger bestScoreCandidateIndex = new AtomicInteger(-1);
|
||||
protected List<ResultReference> allResults = new ArrayList<>();
|
||||
|
||||
protected Map<Integer, CandidateInfo> currentStatus = new ConcurrentHashMap<>(); //TODO: better design possible?
|
||||
|
||||
protected ExecutorService futureListenerExecutor;
|
||||
|
||||
protected List<StatusListener> statusListeners = new ArrayList<>();
|
||||
|
||||
|
||||
protected BaseOptimizationRunner(OptimizationConfiguration config) {
|
||||
this.config = config;
|
||||
|
||||
if (config.getTerminationConditions() == null || config.getTerminationConditions().size() == 0) {
|
||||
throw new IllegalArgumentException("Cannot create BaseOptimizationRunner without TerminationConditions ("
|
||||
+ "termination conditions are null or empty)");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected void init() {
|
||||
futureListenerExecutor = Executors.newFixedThreadPool(maxConcurrentTasks(), new ThreadFactory() {
|
||||
private AtomicLong counter = new AtomicLong(0);
|
||||
|
||||
@Override
|
||||
public Thread newThread(Runnable r) {
|
||||
Thread t = Executors.defaultThreadFactory().newThread(r);
|
||||
t.setDaemon(true);
|
||||
t.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement());
|
||||
return t;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@Override
|
||||
public void execute() {
|
||||
log.info("{}: execution started", this.getClass().getSimpleName());
|
||||
config.setExecutionStartTime(System.currentTimeMillis());
|
||||
for (StatusListener listener : statusListeners) {
|
||||
listener.onInitialization(this);
|
||||
}
|
||||
|
||||
//Initialize termination conditions (start timers, etc)
|
||||
for (TerminationCondition c : config.getTerminationConditions()) {
|
||||
c.initialize(this);
|
||||
}
|
||||
|
||||
//Queue initial tasks:
|
||||
List<Future<OptimizationResult>> tempList = new ArrayList<>(100);
|
||||
while (true) {
|
||||
//Otherwise: add tasks if required
|
||||
Future<OptimizationResult> future = null;
|
||||
try {
|
||||
future = completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT);
|
||||
} catch (InterruptedException e) {
|
||||
//No op?
|
||||
}
|
||||
if (future != null) {
|
||||
tempList.add(future);
|
||||
}
|
||||
completedFutures.drainTo(tempList);
|
||||
|
||||
//Process results (if any)
|
||||
for (Future<OptimizationResult> f : tempList) {
|
||||
queuedFutures.remove(f);
|
||||
processReturnedTask(f);
|
||||
}
|
||||
|
||||
if (tempList.size() > 0) {
|
||||
for (StatusListener sl : statusListeners) {
|
||||
sl.onRunnerStatusChange(this);
|
||||
}
|
||||
}
|
||||
tempList.clear();
|
||||
|
||||
//Check termination conditions:
|
||||
if (terminate()) {
|
||||
shutdown(true);
|
||||
break;
|
||||
}
|
||||
|
||||
//Add additional tasks
|
||||
while (config.getCandidateGenerator().hasMoreCandidates() && queuedFutures.size() < maxConcurrentTasks()) {
|
||||
Candidate candidate = config.getCandidateGenerator().getCandidate();
|
||||
CandidateInfo status;
|
||||
if (candidate.getException() != null) {
|
||||
//Failed on generation...
|
||||
status = processFailedCandidates(candidate);
|
||||
} else {
|
||||
long created = System.currentTimeMillis();
|
||||
ListenableFuture<OptimizationResult> f;
|
||||
if(config.getDataSource() != null){
|
||||
f = execute(candidate, config.getDataSource(), config.getDataSourceProperties(), config.getScoreFunction());
|
||||
} else {
|
||||
f = execute(candidate, config.getDataProvider(), config.getScoreFunction());
|
||||
}
|
||||
f.addListener(new OnCompletionListener(f), futureListenerExecutor);
|
||||
queuedFutures.add(f);
|
||||
totalCandidateCount.getAndIncrement();
|
||||
|
||||
status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Created, null,
|
||||
created, null, null, candidate.getFlatParameters(), null);
|
||||
currentStatus.put(candidate.getIndex(), status);
|
||||
}
|
||||
|
||||
for (StatusListener listener : statusListeners) {
|
||||
listener.onCandidateStatusChange(status, this, null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Process any final (completed) tasks:
|
||||
completedFutures.drainTo(tempList);
|
||||
for (Future<OptimizationResult> f : tempList) {
|
||||
queuedFutures.remove(f);
|
||||
processReturnedTask(f);
|
||||
}
|
||||
tempList.clear();
|
||||
|
||||
log.info("Optimization runner: execution complete");
|
||||
for (StatusListener listener : statusListeners) {
|
||||
listener.onShutdown(this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private CandidateInfo processFailedCandidates(Candidate<?> candidate) {
|
||||
//In case the candidate fails during the creation of the candidate
|
||||
|
||||
long time = System.currentTimeMillis();
|
||||
String stackTrace = ExceptionUtils.getStackTrace(candidate.getException());
|
||||
CandidateInfo newStatus = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, time, time,
|
||||
time, candidate.getFlatParameters(), stackTrace);
|
||||
currentStatus.put(candidate.getIndex(), newStatus);
|
||||
|
||||
return newStatus;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process returned task (either completed or failed
|
||||
*/
|
||||
private void processReturnedTask(Future<OptimizationResult> future) {
|
||||
long currentTime = System.currentTimeMillis();
|
||||
OptimizationResult result;
|
||||
try {
|
||||
result = future.get(100, TimeUnit.MILLISECONDS);
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException("Unexpected InterruptedException thrown for task", e);
|
||||
} catch (ExecutionException e) {
|
||||
//Note that most of the time, an OptimizationResult is returned even for an exception
|
||||
//This is just to handle any that are missed there (or, by implementations that don't properly do this)
|
||||
log.warn("Task failed", e);
|
||||
|
||||
numCandidatesFailed.getAndIncrement();
|
||||
return;
|
||||
} catch (TimeoutException e) {
|
||||
throw new RuntimeException(e); //TODO
|
||||
}
|
||||
|
||||
//Update internal status:
|
||||
CandidateInfo status = currentStatus.get(result.getIndex());
|
||||
CandidateInfo newStatus = new CandidateInfo(result.getIndex(), result.getCandidateInfo().getCandidateStatus(),
|
||||
result.getScore(), status.getCreatedTime(), result.getCandidateInfo().getStartTime(),
|
||||
currentTime, status.getFlatParams(), result.getCandidateInfo().getExceptionStackTrace());
|
||||
currentStatus.put(result.getIndex(), newStatus);
|
||||
|
||||
//Listeners (on complete, etc) should be executed in underlying task
|
||||
|
||||
|
||||
if (result.getCandidateInfo().getCandidateStatus() == CandidateStatus.Failed) {
|
||||
log.info("Task {} failed during execution: {}", result.getIndex(), result.getCandidateInfo().getExceptionStackTrace());
|
||||
numCandidatesFailed.getAndIncrement();
|
||||
} else {
|
||||
|
||||
//Report completion to candidate generator
|
||||
config.getCandidateGenerator().reportResults(result);
|
||||
|
||||
Double score = result.getScore();
|
||||
log.info("Completed task {}, score = {}", result.getIndex(), result.getScore());
|
||||
|
||||
boolean minimize = config.getScoreFunction().minimize();
|
||||
if (score != null && (bestScore == null
|
||||
|| ((minimize && score < bestScore) || (!minimize && score > bestScore)))) {
|
||||
if (bestScore == null) {
|
||||
log.info("New best score: {} (first completed model)", score);
|
||||
} else {
|
||||
int idx = result.getIndex();
|
||||
int lastBestIdx = bestScoreCandidateIndex.get();
|
||||
log.info("New best score: {}, model {} (prev={}, model {})", score, idx, bestScore, lastBestIdx);
|
||||
}
|
||||
bestScore = score;
|
||||
bestScoreTime = System.currentTimeMillis();
|
||||
bestScoreCandidateIndex.set(result.getIndex());
|
||||
}
|
||||
numCandidatesCompleted.getAndIncrement();
|
||||
|
||||
//Model saving is done in the optimization tasks, to avoid CUDA threading issues
|
||||
ResultReference resultReference = result.getResultReference();
|
||||
|
||||
if (resultReference != null)
|
||||
allResults.add(resultReference);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCandidatesTotal() {
|
||||
return totalCandidateCount.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCandidatesCompleted() {
|
||||
return numCandidatesCompleted.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCandidatesFailed() {
|
||||
return numCandidatesFailed.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numCandidatesQueued() {
|
||||
return queuedFutures.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double bestScore() {
|
||||
return bestScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long bestScoreTime() {
|
||||
return bestScoreTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int bestScoreCandidateIndex() {
|
||||
return bestScoreCandidateIndex.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ResultReference> getResults() {
|
||||
return new ArrayList<>(allResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OptimizationConfiguration getConfiguration() {
|
||||
return config;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void addListeners(StatusListener... listeners) {
|
||||
for (StatusListener l : listeners) {
|
||||
if (!statusListeners.contains(l)) {
|
||||
statusListeners.add(l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeListeners(StatusListener... listeners) {
|
||||
for (StatusListener l : listeners) {
|
||||
statusListeners.remove(l);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAllListeners() {
|
||||
statusListeners.clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<CandidateInfo> getCandidateStatus() {
|
||||
return new ArrayList<>(currentStatus.values());
|
||||
}
|
||||
|
||||
private boolean terminate() {
|
||||
for (TerminationCondition c : config.getTerminationConditions()) {
|
||||
if (c.terminate(this)) {
|
||||
log.info("BaseOptimizationRunner global termination condition hit: {}", c);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
private class FutureDetails {
|
||||
private final Future<OptimizationResult> future;
|
||||
private final long startTime;
|
||||
private final int index;
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
private class OnCompletionListener implements Runnable {
|
||||
private Future<OptimizationResult> future;
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
completedFutures.add(future);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
protected abstract int maxConcurrentTasks();
|
||||
|
||||
@Deprecated
|
||||
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
|
||||
ScoreFunction scoreFunction);
|
||||
@Deprecated
|
||||
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates,
|
||||
DataProvider dataProvider, ScoreFunction scoreFunction);
|
||||
|
||||
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource,
|
||||
Properties dataSourceProperties, ScoreFunction scoreFunction);
|
||||
|
||||
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource,
|
||||
Properties dataSourceProperties, ScoreFunction scoreFunction);
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Simple helper class to store status of a candidate that is/has been/will be executed
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public class CandidateInfo {
|
||||
|
||||
public CandidateInfo() {
|
||||
//No arg constructor for Jackson
|
||||
}
|
||||
|
||||
private int index;
|
||||
private CandidateStatus candidateStatus;
|
||||
private Double score;
|
||||
private long createdTime;
|
||||
private Long startTime;
|
||||
private Long endTime;
|
||||
private double[] flatParams; //Same as parameters in Candidate class
|
||||
private String exceptionStackTrace;
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner;
|
||||
|
||||
/**
|
||||
* Status for candidates
|
||||
*/
|
||||
public enum CandidateStatus {
|
||||
Created, Running, Complete, Failed, Cancelled
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface IOptimizationRunner {
|
||||
|
||||
void execute();
|
||||
|
||||
/** Total number of candidates: created (scheduled), completed and failed */
|
||||
int numCandidatesTotal();
|
||||
|
||||
int numCandidatesCompleted();
|
||||
|
||||
int numCandidatesFailed();
|
||||
|
||||
/** Number of candidates running or queued */
|
||||
int numCandidatesQueued();
|
||||
|
||||
/** Best score found so far */
|
||||
Double bestScore();
|
||||
|
||||
/** Time that the best score was found at, or 0 if no jobs have completed successfully */
|
||||
Long bestScoreTime();
|
||||
|
||||
/** Index of the best scoring candidate, or -1 if no candidate has scored yet*/
|
||||
int bestScoreCandidateIndex();
|
||||
|
||||
List<ResultReference> getResults();
|
||||
|
||||
OptimizationConfiguration getConfiguration();
|
||||
|
||||
void addListeners(StatusListener... listeners);
|
||||
|
||||
void removeListeners(StatusListener... listeners);
|
||||
|
||||
void removeAllListeners();
|
||||
|
||||
List<CandidateInfo> getCandidateStatus();
|
||||
|
||||
/**
|
||||
* @param awaitCompletion If true: await completion of currently scheduled tasks. If false: shutdown immediately,
|
||||
* cancelling any currently executing tasks
|
||||
*/
|
||||
void shutdown(boolean awaitCompletion);
|
||||
}
|
|
@ -1,152 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner;
|
||||
|
||||
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
|
||||
import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
|
||||
import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.arbiter.optimize.api.*;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* LocalOptimizationRunner: execute hyperparameter optimization
|
||||
* locally (on current machine, in current JVM).
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class LocalOptimizationRunner extends BaseOptimizationRunner {
|
||||
|
||||
public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1;
|
||||
|
||||
private final int maxConcurrentTasks;
|
||||
|
||||
private TaskCreator taskCreator;
|
||||
private ListeningExecutorService executor;
|
||||
@Setter
|
||||
private long shutdownMaxWaitMS = 2L * 24 * 60 * 60 * 1000;
|
||||
|
||||
public LocalOptimizationRunner(OptimizationConfiguration config){
|
||||
this(config, null);
|
||||
}
|
||||
|
||||
public LocalOptimizationRunner(OptimizationConfiguration config, TaskCreator taskCreator) {
|
||||
this(DEFAULT_MAX_CONCURRENT_TASKS, config, taskCreator);
|
||||
}
|
||||
|
||||
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config){
|
||||
this(maxConcurrentTasks, config, null);
|
||||
}
|
||||
|
||||
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config, TaskCreator taskCreator) {
|
||||
super(config);
|
||||
if (maxConcurrentTasks <= 0)
|
||||
throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")");
|
||||
this.maxConcurrentTasks = maxConcurrentTasks;
|
||||
|
||||
if(taskCreator == null){
|
||||
Class<? extends ParameterSpace> psClass = config.getCandidateGenerator().getParameterSpace().getClass();
|
||||
taskCreator = TaskCreatorProvider.defaultTaskCreatorFor(psClass);
|
||||
if(taskCreator == null){
|
||||
throw new IllegalStateException("No TaskCreator was provided and a default TaskCreator cannot be " +
|
||||
"inferred for ParameterSpace class " + psClass.getName() + ". Please provide a TaskCreator " +
|
||||
"via the LocalOptimizationRunner constructor");
|
||||
}
|
||||
}
|
||||
|
||||
this.taskCreator = taskCreator;
|
||||
|
||||
ExecutorService exec = Executors.newFixedThreadPool(maxConcurrentTasks, new ThreadFactory() {
|
||||
private AtomicLong counter = new AtomicLong(0);
|
||||
|
||||
@Override
|
||||
public Thread newThread(Runnable r) {
|
||||
Thread t = Executors.defaultThreadFactory().newThread(r);
|
||||
t.setDaemon(true);
|
||||
t.setName("LocalCandidateExecutor-" + counter.getAndIncrement());
|
||||
return t;
|
||||
}
|
||||
});
|
||||
executor = MoreExecutors.listeningDecorator(exec);
|
||||
|
||||
init();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int maxConcurrentTasks() {
|
||||
return maxConcurrentTasks;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
|
||||
ScoreFunction scoreFunction) {
|
||||
return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, DataProvider dataProvider,
|
||||
ScoreFunction scoreFunction) {
|
||||
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
|
||||
for (Candidate candidate : candidates) {
|
||||
Callable<OptimizationResult> task =
|
||||
taskCreator.create(candidate, dataProvider, scoreFunction, statusListeners, this);
|
||||
list.add(executor.submit(task));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
|
||||
return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
|
||||
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
|
||||
for (Candidate candidate : candidates) {
|
||||
Callable<OptimizationResult> task = taskCreator.create(candidate, dataSource, dataSourceProperties, scoreFunction, statusListeners, this);
|
||||
list.add(executor.submit(task));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown(boolean awaitTermination) {
|
||||
if(awaitTermination){
|
||||
try {
|
||||
executor.shutdown();
|
||||
executor.awaitTermination(shutdownMaxWaitMS, TimeUnit.MILLISECONDS);
|
||||
} catch (InterruptedException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
executor.shutdownNow();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
|
||||
/**
|
||||
* BaseStatusListener: implements all methods of {@link StatusListener} as no-op.
|
||||
* Users can extend this and override only the methods actually required
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public abstract class BaseStatusListener implements StatusListener{
|
||||
@Override
|
||||
public void onInitialization(IOptimizationRunner runner) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onShutdown(IOptimizationRunner runner) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onRunnerStatusChange(IOptimizationRunner runner) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
|
||||
//No op
|
||||
}
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
||||
|
||||
/**
|
||||
* Created by Alex on 20/07/2017.
|
||||
*/
|
||||
public enum StatusChangeType {
|
||||
|
||||
CandidateCompleted, CandidateFailed, CandidateNewScheduled, CandidateNewBestScore
|
||||
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
|
||||
/**
|
||||
* The status Listener interface is used to inspect/track the status of execution, both for individual candidates,
|
||||
* and for the optimisation runner overall.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface StatusListener {
|
||||
|
||||
/** Called when optimization runner starts execution */
|
||||
void onInitialization(IOptimizationRunner runner);
|
||||
|
||||
/** Called when optimization runner terminates */
|
||||
void onShutdown(IOptimizationRunner runner);
|
||||
|
||||
/** Called when any of the summary stats change, for the optimization runner:
|
||||
* number scheduled, number completed, number failed, best score, etc. */
|
||||
void onRunnerStatusChange(IOptimizationRunner runner);
|
||||
|
||||
/**
|
||||
* Called when the status of the candidate is change. For example created, completed, failed.
|
||||
*
|
||||
* @param candidateInfo Candidate information
|
||||
* @param runner Optimisation runner calling this method
|
||||
* @param result Optimisation result. Maybe null.
|
||||
*/
|
||||
void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result);
|
||||
|
||||
/**
|
||||
* This method may be called by tasks as they are executing. The intent of this method is to report partial results,
|
||||
* such as different stages of learning, or scores/evaluations so far
|
||||
*
|
||||
* @param candidateInfo Candidate information
|
||||
* @param candidate Current candidate value/configuration
|
||||
* @param iteration Current iteration number
|
||||
*/
|
||||
void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration);
|
||||
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.runner.listener.impl;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
|
||||
/**
|
||||
* Created by Alex on 20/07/2017.
|
||||
*/
|
||||
@Slf4j
|
||||
public class LoggingStatusListener implements StatusListener {
|
||||
|
||||
|
||||
@Override
|
||||
public void onInitialization(IOptimizationRunner runner) {
|
||||
log.info("Optimization runner: initialized");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onShutdown(IOptimizationRunner runner) {
|
||||
log.info("Optimization runner: shut down");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onRunnerStatusChange(IOptimizationRunner runner) {
|
||||
log.info("Optimization runner: status change");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner,
|
||||
OptimizationResult result) {
|
||||
log.info("Candidate status change: {}", candidateInfo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
|
||||
log.info("Candidate iteration #{} - {}", iteration, candidate);
|
||||
}
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.nd4j.shade.jackson.core.JsonParser;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
|
||||
/**
|
||||
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
|
||||
@Override
|
||||
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
|
||||
JsonNode node = p.getCodec().readTree(p);
|
||||
String className = node.get("@valueclass").asText();
|
||||
Class<?> c;
|
||||
try {
|
||||
c = Class.forName(className);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if(node.has("value")){
|
||||
//Number, String, Enum
|
||||
JsonNode valueNode = node.get("value");
|
||||
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
||||
return new FixedValue<>(o);
|
||||
} else {
|
||||
//Everything else
|
||||
JsonNode valueNode = node.get("data");
|
||||
String data = valueNode.asText();
|
||||
|
||||
byte[] b = new Base64().decode(data);
|
||||
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
|
||||
try {
|
||||
Object o = ois.readObject();
|
||||
return new FixedValue<>(o);
|
||||
} catch (Throwable t) {
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.net.util.Base64;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||
import org.nd4j.shade.jackson.core.type.WritableTypeId;
|
||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
|
||||
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
|
||||
|
||||
/**
|
||||
* A custom serializer to handle arbitrary object types
|
||||
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
|
||||
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
|
||||
* objects very well
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
|
||||
@Override
|
||||
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
|
||||
Object o = fixedValue.getValue();
|
||||
|
||||
j.writeStringField("@valueclass", o.getClass().getName());
|
||||
if(o instanceof Number || o instanceof String || o instanceof Enum){
|
||||
j.writeObjectField("value", o);
|
||||
} else {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||
oos.writeObject(o);
|
||||
baos.close();
|
||||
byte[] b = baos.toByteArray();
|
||||
String base64 = new Base64().encodeToString(b);
|
||||
j.writeStringField("data", base64);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
|
||||
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
|
||||
typeSer.writeTypePrefix(gen, typeId);
|
||||
serialize(value, gen, serializers);
|
||||
typeSer.writeTypeSuffix(gen, typeId);
|
||||
}
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
import org.nd4j.shade.jackson.core.JsonParser;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Custom Jackson deserializer for integer distributions
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class IntegerDistributionDeserializer extends JsonDeserializer<IntegerDistribution> {
|
||||
|
||||
@Override
|
||||
public IntegerDistribution deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
|
||||
JsonNode node = p.getCodec().readTree(p);
|
||||
String simpleName = node.get("distribution").asText();
|
||||
|
||||
switch (simpleName) {
|
||||
case "BinomialDistribution":
|
||||
return new BinomialDistribution(node.get("trials").asInt(), node.get("p").asDouble());
|
||||
case "GeometricDistribution":
|
||||
return new GeometricDistribution(node.get("p").asDouble());
|
||||
case "HypergeometricDistribution":
|
||||
return new HypergeometricDistribution(node.get("populationSize").asInt(),
|
||||
node.get("numberOfSuccesses").asInt(), node.get("sampleSize").asInt());
|
||||
case "PascalDistribution":
|
||||
return new PascalDistribution(node.get("r").asInt(), node.get("p").asDouble());
|
||||
case "PoissonDistribution":
|
||||
return new PoissonDistribution(node.get("p").asDouble());
|
||||
case "UniformIntegerDistribution":
|
||||
return new UniformIntegerDistribution(node.get("lower").asInt(), node.get("upper").asInt());
|
||||
case "ZipfDistribution":
|
||||
return new ZipfDistribution(node.get("numElements").asInt(), node.get("exponent").asDouble());
|
||||
default:
|
||||
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Custom Jackson serializer for integer distributions
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class IntegerDistributionSerializer extends JsonSerializer<IntegerDistribution> {
|
||||
@Override
|
||||
public void serialize(IntegerDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
|
||||
throws IOException {
|
||||
Class<?> c = d.getClass();
|
||||
String s = c.getSimpleName();
|
||||
|
||||
j.writeStartObject();
|
||||
j.writeStringField("distribution", s);
|
||||
|
||||
if (c == BinomialDistribution.class) {
|
||||
BinomialDistribution bd = (BinomialDistribution) d;
|
||||
j.writeNumberField("trials", bd.getNumberOfTrials());
|
||||
j.writeNumberField("p", bd.getProbabilityOfSuccess());
|
||||
} else if (c == GeometricDistribution.class) {
|
||||
GeometricDistribution gd = (GeometricDistribution) d;
|
||||
j.writeNumberField("p", gd.getProbabilityOfSuccess());
|
||||
} else if (c == HypergeometricDistribution.class) {
|
||||
HypergeometricDistribution hd = (HypergeometricDistribution) d;
|
||||
j.writeNumberField("populationSize", hd.getPopulationSize());
|
||||
j.writeNumberField("numberOfSuccesses", hd.getNumberOfSuccesses());
|
||||
j.writeNumberField("sampleSize", hd.getSampleSize());
|
||||
} else if (c == PascalDistribution.class) {
|
||||
PascalDistribution pd = (PascalDistribution) d;
|
||||
j.writeNumberField("r", pd.getNumberOfSuccesses());
|
||||
j.writeNumberField("p", pd.getProbabilityOfSuccess());
|
||||
} else if (c == PoissonDistribution.class) {
|
||||
PoissonDistribution pd = (PoissonDistribution) d;
|
||||
j.writeNumberField("p", pd.getMean());
|
||||
} else if (c == UniformIntegerDistribution.class) {
|
||||
UniformIntegerDistribution ud = (UniformIntegerDistribution) d;
|
||||
j.writeNumberField("lower", ud.getSupportLowerBound());
|
||||
j.writeNumberField("upper", ud.getSupportUpperBound());
|
||||
} else if (c == ZipfDistribution.class) {
|
||||
ZipfDistribution zd = (ZipfDistribution) d;
|
||||
j.writeNumberField("numElements", zd.getNumberOfElements());
|
||||
j.writeNumberField("exponent", zd.getExponent());
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
||||
}
|
||||
|
||||
j.writeEndObject();
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||
|
||||
/**
|
||||
* Created by Alex on 16/11/2016.
|
||||
*/
|
||||
public class JsonMapper {
|
||||
|
||||
private static ObjectMapper mapper;
|
||||
private static ObjectMapper yamlMapper;
|
||||
|
||||
static {
|
||||
mapper = new ObjectMapper();
|
||||
mapper.registerModule(new JodaModule());
|
||||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
|
||||
yamlMapper = new ObjectMapper(new YAMLFactory());
|
||||
yamlMapper.registerModule(new JodaModule());
|
||||
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||
yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||
yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||
yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||
}
|
||||
|
||||
private JsonMapper() {}
|
||||
|
||||
|
||||
/**
|
||||
* Return the yaml mapper
|
||||
* @return
|
||||
*/
|
||||
public static ObjectMapper getYamlMapper() {
|
||||
return yamlMapper;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a json mapper
|
||||
* @return
|
||||
*/
|
||||
public static ObjectMapper getMapper() {
|
||||
return mapper;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
|
||||
import org.nd4j.shade.jackson.core.JsonParser;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Created by Alex on 14/02/2017.
|
||||
*/
|
||||
public class RealDistributionDeserializer extends JsonDeserializer<RealDistribution> {
|
||||
|
||||
@Override
|
||||
public RealDistribution deserialize(JsonParser p, DeserializationContext ctxt)
|
||||
throws IOException, JsonProcessingException {
|
||||
JsonNode node = p.getCodec().readTree(p);
|
||||
String simpleName = node.get("distribution").asText();
|
||||
|
||||
switch (simpleName) {
|
||||
case "BetaDistribution":
|
||||
return new BetaDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
|
||||
case "CauchyDistribution":
|
||||
return new CauchyDistribution(node.get("median").asDouble(), node.get("scale").asDouble());
|
||||
case "ChiSquaredDistribution":
|
||||
return new ChiSquaredDistribution(node.get("dof").asDouble());
|
||||
case "ExponentialDistribution":
|
||||
return new ExponentialDistribution(node.get("mean").asDouble());
|
||||
case "FDistribution":
|
||||
return new FDistribution(node.get("numeratorDof").asDouble(), node.get("denominatorDof").asDouble());
|
||||
case "GammaDistribution":
|
||||
return new GammaDistribution(node.get("shape").asDouble(), node.get("scale").asDouble());
|
||||
case "LevyDistribution":
|
||||
return new LevyDistribution(node.get("mu").asDouble(), node.get("c").asDouble());
|
||||
case "LogNormalDistribution":
|
||||
return new LogNormalDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
|
||||
case "NormalDistribution":
|
||||
return new NormalDistribution(node.get("mean").asDouble(), node.get("stdev").asDouble());
|
||||
case "ParetoDistribution":
|
||||
return new ParetoDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
|
||||
case "TDistribution":
|
||||
return new TDistribution(node.get("dof").asDouble());
|
||||
case "TriangularDistribution":
|
||||
return new TriangularDistribution(node.get("a").asDouble(), node.get("b").asDouble(),
|
||||
node.get("c").asDouble());
|
||||
case "UniformRealDistribution":
|
||||
return new UniformRealDistribution(node.get("lower").asDouble(), node.get("upper").asDouble());
|
||||
case "WeibullDistribution":
|
||||
return new WeibullDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
|
||||
case "LogUniformDistribution":
|
||||
return new LogUniformDistribution(node.get("min").asDouble(), node.get("max").asDouble());
|
||||
default:
|
||||
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -1,109 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.math3.distribution.*;
|
||||
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
|
||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Custom JSON serializer for Apache commons RealDistribution instances.
|
||||
* The custom serializer is set up to use the built-in c
|
||||
*/
|
||||
public class RealDistributionSerializer extends JsonSerializer<RealDistribution> {
|
||||
|
||||
@Override
|
||||
public void serialize(RealDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
|
||||
throws IOException {
|
||||
Class<?> c = d.getClass();
|
||||
String s = c.getSimpleName();
|
||||
|
||||
j.writeStartObject();
|
||||
j.writeStringField("distribution", s);
|
||||
|
||||
|
||||
if (c == BetaDistribution.class) {
|
||||
BetaDistribution bd = (BetaDistribution) d;
|
||||
j.writeNumberField("alpha", bd.getAlpha());
|
||||
j.writeNumberField("beta", bd.getBeta());
|
||||
} else if (c == CauchyDistribution.class) {
|
||||
CauchyDistribution cd = (CauchyDistribution) d;
|
||||
j.writeNumberField("median", cd.getMedian());
|
||||
j.writeNumberField("scale", cd.getScale());
|
||||
} else if (c == ChiSquaredDistribution.class) {
|
||||
ChiSquaredDistribution cd = (ChiSquaredDistribution) d;
|
||||
j.writeNumberField("dof", cd.getDegreesOfFreedom());
|
||||
} else if (c == ExponentialDistribution.class) {
|
||||
ExponentialDistribution ed = (ExponentialDistribution) d;
|
||||
j.writeNumberField("mean", ed.getMean());
|
||||
} else if (c == FDistribution.class) {
|
||||
FDistribution fd = (FDistribution) d;
|
||||
j.writeNumberField("numeratorDof", fd.getNumeratorDegreesOfFreedom());
|
||||
j.writeNumberField("denominatorDof", fd.getDenominatorDegreesOfFreedom());
|
||||
} else if (c == GammaDistribution.class) {
|
||||
GammaDistribution gd = (GammaDistribution) d;
|
||||
j.writeNumberField("shape", gd.getShape());
|
||||
j.writeNumberField("scale", gd.getScale());
|
||||
} else if (c == LevyDistribution.class) {
|
||||
LevyDistribution ld = (LevyDistribution) d;
|
||||
j.writeNumberField("mu", ld.getLocation());
|
||||
j.writeNumberField("c", ld.getScale());
|
||||
} else if (c == LogNormalDistribution.class) {
|
||||
LogNormalDistribution ln = (LogNormalDistribution) d;
|
||||
j.writeNumberField("scale", ln.getScale());
|
||||
j.writeNumberField("shape", ln.getShape());
|
||||
} else if (c == NormalDistribution.class) {
|
||||
NormalDistribution nd = (NormalDistribution) d;
|
||||
j.writeNumberField("mean", nd.getMean());
|
||||
j.writeNumberField("stdev", nd.getStandardDeviation());
|
||||
} else if (c == ParetoDistribution.class) {
|
||||
ParetoDistribution pd = (ParetoDistribution) d;
|
||||
j.writeNumberField("scale", pd.getScale());
|
||||
j.writeNumberField("shape", pd.getShape());
|
||||
} else if (c == TDistribution.class) {
|
||||
TDistribution td = (TDistribution) d;
|
||||
j.writeNumberField("dof", td.getDegreesOfFreedom());
|
||||
} else if (c == TriangularDistribution.class) {
|
||||
TriangularDistribution td = (TriangularDistribution) d;
|
||||
j.writeNumberField("a", td.getSupportLowerBound());
|
||||
j.writeNumberField("b", td.getMode());
|
||||
j.writeNumberField("c", td.getSupportUpperBound());
|
||||
} else if (c == UniformRealDistribution.class) {
|
||||
UniformRealDistribution u = (UniformRealDistribution) d;
|
||||
j.writeNumberField("lower", u.getSupportLowerBound());
|
||||
j.writeNumberField("upper", u.getSupportUpperBound());
|
||||
} else if (c == WeibullDistribution.class) {
|
||||
WeibullDistribution wb = (WeibullDistribution) d;
|
||||
j.writeNumberField("alpha", wb.getShape());
|
||||
j.writeNumberField("beta", wb.getScale());
|
||||
} else if (c == LogUniformDistribution.class){
|
||||
LogUniformDistribution lud = (LogUniformDistribution) d;
|
||||
j.writeNumberField("min", lud.getMin());
|
||||
j.writeNumberField("max", lud.getMax());
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + d.getClass());
|
||||
}
|
||||
|
||||
j.writeEndObject();
|
||||
}
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||
|
||||
/**
|
||||
* Created by Alex on 16/11/2016.
|
||||
*/
|
||||
public class YamlMapper {
|
||||
|
||||
private static final ObjectMapper mapper;
|
||||
|
||||
static {
|
||||
mapper = new ObjectMapper(new YAMLFactory());
|
||||
mapper.registerModule(new JodaModule());
|
||||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||
}
|
||||
|
||||
|
||||
private YamlMapper() {}
|
||||
|
||||
public static ObjectMapper getMapper() {
|
||||
return mapper;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,235 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.util;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
|
||||
/**
|
||||
* Simple utility class used to get access to files at the classpath, or packed into jar.
|
||||
* Based on Spring ClassPathResource implementation + jar internals access implemented.
|
||||
*
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ClassPathResource {
|
||||
|
||||
private String resourceName;
|
||||
|
||||
private static Logger log = LoggerFactory.getLogger(ClassPathResource.class);
|
||||
|
||||
/**
|
||||
* Builds new ClassPathResource object
|
||||
*
|
||||
* @param resourceName String name of resource, to be retrieved
|
||||
*/
|
||||
public ClassPathResource(String resourceName) {
|
||||
if (resourceName == null)
|
||||
throw new IllegalStateException("Resource name can't be null");
|
||||
this.resourceName = resourceName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns URL of the requested resource
|
||||
*
|
||||
* @return URL of the resource, if it's available in current Jar
|
||||
*/
|
||||
private URL getUrl() {
|
||||
ClassLoader loader = null;
|
||||
try {
|
||||
loader = Thread.currentThread().getContextClassLoader();
|
||||
} catch (Exception e) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
if (loader == null) {
|
||||
loader = ClassPathResource.class.getClassLoader();
|
||||
}
|
||||
|
||||
URL url = loader.getResource(this.resourceName);
|
||||
if (url == null) {
|
||||
// try to check for mis-used starting slash
|
||||
// TODO: see TODO below
|
||||
if (this.resourceName.startsWith("/")) {
|
||||
url = loader.getResource(this.resourceName.replaceFirst("[\\\\/]", ""));
|
||||
if (url != null)
|
||||
return url;
|
||||
} else {
|
||||
// try to add slash, to make clear it's not an issue
|
||||
// TODO: change this mechanic to actual path purifier
|
||||
url = loader.getResource("/" + this.resourceName);
|
||||
if (url != null)
|
||||
return url;
|
||||
}
|
||||
throw new IllegalStateException("Resource '" + this.resourceName + "' cannot be found.");
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns requested ClassPathResource as File object
|
||||
*
|
||||
* Please note: if this method called from compiled jar, temporary file will be created to provide File access
|
||||
*
|
||||
* @return File requested at constructor call
|
||||
* @throws FileNotFoundException
|
||||
*/
|
||||
public File getFile() throws FileNotFoundException {
|
||||
URL url = this.getUrl();
|
||||
|
||||
if (isJarURL(url)) {
|
||||
/*
|
||||
This is actually request for file, that's packed into jar. Probably the current one, but that doesn't matters.
|
||||
*/
|
||||
try {
|
||||
url = extractActualUrl(url);
|
||||
File file = File.createTempFile("canova_temp", "file");
|
||||
file.deleteOnExit();
|
||||
|
||||
ZipFile zipFile = new ZipFile(url.getFile());
|
||||
ZipEntry entry = zipFile.getEntry(this.resourceName);
|
||||
if (entry == null) {
|
||||
if (this.resourceName.startsWith("/")) {
|
||||
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
|
||||
if (entry == null) {
|
||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||
}
|
||||
} else
|
||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||
}
|
||||
|
||||
long size = entry.getSize();
|
||||
|
||||
InputStream stream = zipFile.getInputStream(entry);
|
||||
FileOutputStream outputStream = new FileOutputStream(file);
|
||||
byte[] array = new byte[1024];
|
||||
int rd = 0;
|
||||
long bytesRead = 0;
|
||||
do {
|
||||
rd = stream.read(array);
|
||||
outputStream.write(array, 0, rd);
|
||||
bytesRead += rd;
|
||||
} while (bytesRead < size);
|
||||
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
|
||||
stream.close();
|
||||
zipFile.close();
|
||||
|
||||
return file;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
} else {
|
||||
/*
|
||||
It's something in the actual underlying filesystem, so we can just go for it
|
||||
*/
|
||||
|
||||
try {
|
||||
URI uri = new URI(url.toString().replaceAll(" ", "%20"));
|
||||
return new File(uri.getSchemeSpecificPart());
|
||||
} catch (URISyntaxException e) {
|
||||
return new File(url.getFile());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks, if proposed URL is packed into archive.
|
||||
*
|
||||
* @param url URL to be checked
|
||||
* @return True, if URL is archive entry, False otherwise
|
||||
*/
|
||||
private boolean isJarURL(URL url) {
|
||||
String protocol = url.getProtocol();
|
||||
return "jar".equals(protocol) || "zip".equals(protocol) || "wsjar".equals(protocol)
|
||||
|| "code-source".equals(protocol) && url.getPath().contains("!/");
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts parent Jar URL from original ClassPath entry URL.
|
||||
*
|
||||
* @param jarUrl Original URL of the resource
|
||||
* @return URL of the Jar file, containing requested resource
|
||||
* @throws MalformedURLException
|
||||
*/
|
||||
private URL extractActualUrl(URL jarUrl) throws MalformedURLException {
|
||||
String urlFile = jarUrl.getFile();
|
||||
int separatorIndex = urlFile.indexOf("!/");
|
||||
if (separatorIndex != -1) {
|
||||
String jarFile = urlFile.substring(0, separatorIndex);
|
||||
|
||||
try {
|
||||
return new URL(jarFile);
|
||||
} catch (MalformedURLException var5) {
|
||||
if (!jarFile.startsWith("/")) {
|
||||
jarFile = "/" + jarFile;
|
||||
}
|
||||
|
||||
return new URL("file:" + jarFile);
|
||||
}
|
||||
} else {
|
||||
return jarUrl;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns requested ClassPathResource as InputStream object
|
||||
*
|
||||
* @return File requested at constructor call
|
||||
* @throws FileNotFoundException
|
||||
*/
|
||||
public InputStream getInputStream() throws FileNotFoundException {
|
||||
URL url = this.getUrl();
|
||||
if (isJarURL(url)) {
|
||||
try {
|
||||
url = extractActualUrl(url);
|
||||
ZipFile zipFile = new ZipFile(url.getFile());
|
||||
ZipEntry entry = zipFile.getEntry(this.resourceName);
|
||||
|
||||
if (entry == null) {
|
||||
if (this.resourceName.startsWith("/")) {
|
||||
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
|
||||
if (entry == null) {
|
||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||
}
|
||||
} else
|
||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||
}
|
||||
|
||||
return zipFile.getInputStream(entry);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
File srcFile = this.getFile();
|
||||
return new FileInputStream(srcFile);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
||||
public class CollectionUtils {
|
||||
|
||||
/**
|
||||
* Count the number of unique values in a collection
|
||||
*/
|
||||
public static int countUnique(Collection<?> collection) {
|
||||
HashSet<Object> set = new HashSet<>(collection);
|
||||
return set.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a list containing only unique values in a collection
|
||||
*/
|
||||
public static <T> List<T> getUnique(Collection<T> collection) {
|
||||
HashSet<T> set = new HashSet<>();
|
||||
List<T> out = new ArrayList<>();
|
||||
for (T t : collection) {
|
||||
if (!set.contains(t)) {
|
||||
out.add(t);
|
||||
set.add(t);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.util;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Created by Alex on 29/06/2017.
|
||||
*/
|
||||
public class LeafUtils {
|
||||
|
||||
private LeafUtils() {}
|
||||
|
||||
/**
|
||||
* Returns a list of unique objects, not using the .equals() method, but rather using ==
|
||||
*
|
||||
* @param allLeaves Leaf values to process
|
||||
* @return A list of unique parameter space values
|
||||
*/
|
||||
public static List<ParameterSpace> getUniqueObjects(List<ParameterSpace> allLeaves) {
|
||||
List<ParameterSpace> unique = new ArrayList<>();
|
||||
for (ParameterSpace p : allLeaves) {
|
||||
//This isn't especially efficient, but small number of parameters in general means it's fine
|
||||
boolean found = false;
|
||||
for (ParameterSpace q : unique) {
|
||||
if (p == q) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
unique.add(p);
|
||||
}
|
||||
}
|
||||
|
||||
return unique;
|
||||
}
|
||||
|
||||
/**
|
||||
* Count the number of unique parameters in the specified leaf nodes
|
||||
*
|
||||
* @param allLeaves Leaf values to count the parameters fore
|
||||
* @return Number of parameters for all unique objects
|
||||
*/
|
||||
public static int countUniqueParameters(List<ParameterSpace> allLeaves) {
|
||||
List<ParameterSpace> unique = getUniqueObjects(allLeaves);
|
||||
int count = 0;
|
||||
for (ParameterSpace ps : unique) {
|
||||
if (!ps.isLeaf()) {
|
||||
throw new IllegalStateException("Method should only be used with leaf nodes");
|
||||
}
|
||||
count += ps.numParameters();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.util;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class ObjectUtils {
|
||||
|
||||
private ObjectUtils() {}
|
||||
|
||||
/**
|
||||
* Get the string representation of the object. Arrays, including primitive arrays, are printed using
|
||||
* Arrays.toString(...) methods.
|
||||
*
|
||||
* @param v Value to convert to a string
|
||||
* @return String representation
|
||||
*/
|
||||
public static String valueToString(Object v) {
|
||||
if (v.getClass().isArray()) {
|
||||
if (v.getClass().getComponentType().isPrimitive()) {
|
||||
Class<?> c = v.getClass().getComponentType();
|
||||
if (c == int.class) {
|
||||
return Arrays.toString((int[]) v);
|
||||
} else if (c == double.class) {
|
||||
return Arrays.toString((double[]) v);
|
||||
} else if (c == float.class) {
|
||||
return Arrays.toString((float[]) v);
|
||||
} else if (c == long.class) {
|
||||
return Arrays.toString((long[]) v);
|
||||
} else if (c == byte.class) {
|
||||
return Arrays.toString((byte[]) v);
|
||||
} else if (c == short.class) {
|
||||
return Arrays.toString((short[]) v);
|
||||
} else {
|
||||
return v.toString();
|
||||
}
|
||||
} else {
|
||||
return Arrays.toString((Object[]) v);
|
||||
}
|
||||
} else {
|
||||
return v.toString();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.nd4j.common.tests.AbstractAssertTestsClass;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
|
||||
* extends BaseDl4jTest - either directly or indirectly.
|
||||
* Other than a small set of exceptions, all tests must extend this
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
|
||||
@Slf4j
|
||||
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
|
||||
|
||||
@Override
|
||||
protected Set<Class<?>> getExclusions() {
|
||||
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getPackageName() {
|
||||
return "org.deeplearning4j.arbiter.optimize";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Class<?> getBaseClass() {
|
||||
return BaseDL4JTest.class;
|
||||
}
|
||||
}
|
|
@ -1,158 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.deeplearning4j.arbiter.optimize.api.*;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
public class BraninFunction {
|
||||
public static class BraninSpace extends AbstractParameterSpace<BraninConfig> {
|
||||
private int[] indices;
|
||||
private ParameterSpace<Double> first = new ContinuousParameterSpace(-5, 10);
|
||||
private ParameterSpace<Double> second = new ContinuousParameterSpace(0, 15);
|
||||
|
||||
@Override
|
||||
public BraninConfig getValue(double[] parameterValues) {
|
||||
double f = first.getValue(parameterValues);
|
||||
double s = second.getValue(parameterValues);
|
||||
return new BraninConfig(f, s); //-5 to +10 and 0 to 15
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numParameters() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParameterSpace> collectLeaves() {
|
||||
List<ParameterSpace> list = new ArrayList<>();
|
||||
list.addAll(first.collectLeaves());
|
||||
list.addAll(second.collectLeaves());
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLeaf() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIndices(int... indices) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public static class BraninConfig implements Serializable {
|
||||
private double x1;
|
||||
private double x2;
|
||||
}
|
||||
|
||||
public static class BraninScoreFunction implements ScoreFunction {
|
||||
private static final double a = 1.0;
|
||||
private static final double b = 5.1 / (4.0 * Math.PI * Math.PI);
|
||||
private static final double c = 5.0 / Math.PI;
|
||||
private static final double r = 6.0;
|
||||
private static final double s = 10.0;
|
||||
private static final double t = 1.0 / (8.0 * Math.PI);
|
||||
|
||||
@Override
|
||||
public double score(Object m, DataProvider data, Map<String, Object> dataParameters) {
|
||||
BraninConfig model = (BraninConfig) m;
|
||||
double x1 = model.getX1();
|
||||
double x2 = model.getX2();
|
||||
|
||||
return a * Math.pow(x2 - b * x1 * x1 + c * x1 - r, 2.0) + s * (1 - t) * Math.cos(x1) + s;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedModelTypes() {
|
||||
return Collections.<Class<?>>singletonList(BraninConfig.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Class<?>> getSupportedDataTypes() {
|
||||
return Collections.<Class<?>>singletonList(Object.class);
|
||||
}
|
||||
}
|
||||
|
||||
public static class BraninTaskCreator implements TaskCreator {
|
||||
@Override
|
||||
public Callable<OptimizationResult> create(final Candidate c, DataProvider dataProvider,
|
||||
final ScoreFunction scoreFunction, final List<StatusListener> statusListeners,
|
||||
IOptimizationRunner runner) {
|
||||
|
||||
return new Callable<OptimizationResult>() {
|
||||
@Override
|
||||
public OptimizationResult call() throws Exception {
|
||||
|
||||
BraninConfig candidate = (BraninConfig) c.getValue();
|
||||
|
||||
double score = scoreFunction.score(candidate, null, (Map) null);
|
||||
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
|
||||
|
||||
Thread.sleep(20);
|
||||
|
||||
if (statusListeners != null) {
|
||||
for (StatusListener sl : statusListeners) {
|
||||
sl.onCandidateIteration(null, null, 0);
|
||||
}
|
||||
}
|
||||
|
||||
CandidateInfo ci = new CandidateInfo(-1, CandidateStatus.Complete, score,
|
||||
System.currentTimeMillis(), null, null, null, null);
|
||||
|
||||
return new OptimizationResult(c, score, c.getIndex(), null, ci, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource,
|
||||
Properties dataSourceProperties, ScoreFunction scoreFunction,
|
||||
List<StatusListener> statusListeners, IOptimizationRunner runner) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,120 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestGeneticSearch extends BaseDL4JTest {
|
||||
public class TestSelectionOperator extends SelectionOperator {
|
||||
|
||||
@Override
|
||||
public double[] buildNextGenes() {
|
||||
throw new GeneticGenerationException("Forced exception to test exception handling.");
|
||||
}
|
||||
}
|
||||
|
||||
public class TestTerminationCondition implements TerminationCondition {
|
||||
|
||||
public boolean hasAFailedCandidate = false;
|
||||
public int evalCount = 0;
|
||||
|
||||
@Override
|
||||
public void initialize(IOptimizationRunner optimizationRunner) {}
|
||||
|
||||
@Override
|
||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
||||
if (++evalCount == 50) {
|
||||
// Generator did not handle GeneticGenerationException
|
||||
return true;
|
||||
}
|
||||
|
||||
for (CandidateInfo candidateInfo : optimizationRunner.getCandidateStatus()) {
|
||||
if (candidateInfo.getCandidateStatus() == CandidateStatus.Failed) {
|
||||
hasAFailedCandidate = true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void GeneticSearchCandidateGenerator_getCandidate_ShouldGenerateCandidates() throws Exception {
|
||||
|
||||
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator =
|
||||
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
|
||||
.build();
|
||||
|
||||
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
|
||||
.terminationConditions(new MaxCandidatesCondition(50), testTerminationCondition).build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
||||
|
||||
runner.addListeners(new LoggingStatusListener());
|
||||
runner.execute();
|
||||
|
||||
Assert.assertFalse(testTerminationCondition.hasAFailedCandidate);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void GeneticSearchCandidateGenerator_getCandidate_GeneticExceptionShouldMarkCandidateAsFailed() {
|
||||
|
||||
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator =
|
||||
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
|
||||
.selectionOperator(new TestSelectionOperator()).build();
|
||||
|
||||
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
|
||||
.terminationConditions(testTerminationCondition).build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
||||
|
||||
runner.addListeners(new LoggingStatusListener());
|
||||
runner.execute();
|
||||
|
||||
Assert.assertTrue(testTerminationCondition.hasAFailedCandidate);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestGridSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testIndexing() {
|
||||
int[] nValues = {2, 3};
|
||||
int prod = 2 * 3;
|
||||
double[][] expVals = new double[][] {{0.0, 0.0}, {1.0, 0.0}, {0.0, 0.5}, {1.0, 0.5}, {0.0, 1.0}, {1.0, 1.0}};
|
||||
for (int i = 0; i < prod; i++) {
|
||||
double[] out = GridSearchCandidateGenerator.indexToValues(nValues, i, prod);
|
||||
double[] exp = expVals[i];
|
||||
assertArrayEquals(exp, out, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGeneration() throws Exception {
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
|
||||
GridSearchCandidateGenerator.Mode.Sequential, commands);
|
||||
|
||||
//Check sequential:
|
||||
double[] expValuesFirst = {-5, 0, 5, 10}; //Range: -5 to +10, with 4 values
|
||||
double[] expValuesSecond = {0, 5, 10, 15}; //Range: 0 to +15, with 4 values
|
||||
for (int i = 0; i < 4 * 4; i++) {
|
||||
BraninFunction.BraninConfig conf = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
|
||||
double expF = expValuesFirst[i % 4]; //Changes most rapidly
|
||||
double expS = expValuesSecond[i / 4];
|
||||
|
||||
double actF = conf.getX1();
|
||||
double actS = conf.getX2();
|
||||
|
||||
assertEquals(expF, actF, 1e-4);
|
||||
assertEquals(expS, actS, 1e-4);
|
||||
}
|
||||
|
||||
//Check random order. specifically: check that all values are generated, in some order
|
||||
double[][] orderedOutput = new double[16][2];
|
||||
for (int i = 0; i < expValuesFirst.length; i++) {
|
||||
for (int j = 0; j < expValuesSecond.length; j++) {
|
||||
orderedOutput[4 * j + i][0] = expValuesFirst[i];
|
||||
orderedOutput[4 * j + i][1] = expValuesSecond[j];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
|
||||
GridSearchCandidateGenerator.Mode.RandomOrder, commands);
|
||||
boolean[] seen = new boolean[16];
|
||||
int seenCount = 0;
|
||||
for (int i = 0; i < 4 * 4; i++) {
|
||||
assertTrue(candidateGenerator.hasMoreCandidates());
|
||||
BraninFunction.BraninConfig config = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
|
||||
double x1 = config.getX1();
|
||||
double x2 = config.getX2();
|
||||
//Work out which of the values this is...
|
||||
boolean matched = false;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (Math.abs(orderedOutput[j][0] - x1) < 1e-5 && Math.abs(orderedOutput[j][1] - x2) < 1e-5) {
|
||||
matched = true;
|
||||
if (seen[j])
|
||||
fail("Same candidate generated multiple times");
|
||||
seen[j] = true;
|
||||
seenCount++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assertTrue("Candidate " + x1 + ", " + x2 + " not found; invalid?", matched);
|
||||
}
|
||||
assertFalse(candidateGenerator.hasMoreCandidates());
|
||||
assertEquals(16, seenCount);
|
||||
}
|
||||
}
|
|
@ -1,124 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* Created by Alex on 02/02/2017.
|
||||
*/
|
||||
public class TestJson extends BaseDL4JTest {
|
||||
|
||||
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
||||
ObjectMapper om = new ObjectMapper(factory);
|
||||
om.registerModule(new JodaModule());
|
||||
om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||
om.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||
om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||
return om;
|
||||
}
|
||||
|
||||
private static ObjectMapper jsonMapper = getObjectMapper(new JsonFactory());
|
||||
private static ObjectMapper yamlMapper = getObjectMapper(new YAMLFactory());
|
||||
|
||||
|
||||
@Test
|
||||
public void testParameterSpaceJson() throws Exception {
|
||||
|
||||
List<ParameterSpace<?>> l = new ArrayList<>();
|
||||
l.add(new FixedValue<>(1.0));
|
||||
l.add(new FixedValue<>(1));
|
||||
l.add(new FixedValue<>("string"));
|
||||
l.add(new ContinuousParameterSpace(-1, 1));
|
||||
l.add(new ContinuousParameterSpace(new LogNormalDistribution(1, 1)));
|
||||
l.add(new ContinuousParameterSpace(new NormalDistribution(2, 0.01)));
|
||||
l.add(new DiscreteParameterSpace<>(1, 5, 7));
|
||||
l.add(new DiscreteParameterSpace<>("first", "second", "third"));
|
||||
l.add(new IntegerParameterSpace(0, 10));
|
||||
l.add(new IntegerParameterSpace(new UniformIntegerDistribution(0, 50)));
|
||||
l.add(new BooleanSpace());
|
||||
|
||||
for (ParameterSpace<?> ps : l) {
|
||||
String strJson = jsonMapper.writeValueAsString(ps);
|
||||
String strYaml = yamlMapper.writeValueAsString(ps);
|
||||
|
||||
ParameterSpace<?> fromJson = jsonMapper.readValue(strJson, ParameterSpace.class);
|
||||
ParameterSpace<?> fromYaml = yamlMapper.readValue(strYaml, ParameterSpace.class);
|
||||
|
||||
assertEquals(ps, fromJson);
|
||||
assertEquals(ps, fromYaml);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCandidateGeneratorJson() throws Exception {
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
||||
|
||||
List<CandidateGenerator> l = new ArrayList<>();
|
||||
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
|
||||
GridSearchCandidateGenerator.Mode.Sequential, commands));
|
||||
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
|
||||
GridSearchCandidateGenerator.Mode.RandomOrder, commands));
|
||||
l.add(new RandomSearchGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), commands));
|
||||
|
||||
for (CandidateGenerator cg : l) {
|
||||
String strJson = jsonMapper.writeValueAsString(cg);
|
||||
String strYaml = yamlMapper.writeValueAsString(cg);
|
||||
|
||||
CandidateGenerator fromJson = jsonMapper.readValue(strJson, CandidateGenerator.class);
|
||||
CandidateGenerator fromYaml = yamlMapper.readValue(strYaml, CandidateGenerator.class);
|
||||
|
||||
assertEquals(cg, fromJson);
|
||||
assertEquals(cg, fromYaml);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
*
|
||||
* Test random search on the Branin Function:
|
||||
* http://www.sfu.ca/~ssurjano/branin.html
|
||||
*/
|
||||
public class TestRandomSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(new BraninFunction.BraninSpace(), commands);
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).scoreFunction(new BraninFunction.BraninScoreFunction())
|
||||
.terminationConditions(new MaxCandidatesCondition(50)).build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
||||
|
||||
runner.addListeners(new LoggingStatusListener());
|
||||
runner.execute();
|
||||
|
||||
|
||||
// System.out.println("----- Complete -----");
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,72 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestLogUniform extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testSimple(){
|
||||
|
||||
double min = 0.5;
|
||||
double max = 3;
|
||||
|
||||
double logMin = Math.log(min);
|
||||
double logMax = Math.log(max);
|
||||
|
||||
RealDistribution rd = new LogUniformDistribution(min, max);
|
||||
|
||||
for(double d = 0.1; d<= 3.5; d+= 0.1){
|
||||
double density = rd.density(d);
|
||||
double cumulative = rd.cumulativeProbability(d);
|
||||
double dExp;
|
||||
double cumExp;
|
||||
if(d < min){
|
||||
dExp = 0;
|
||||
cumExp = 0;
|
||||
} else if( d > max){
|
||||
dExp = 0;
|
||||
cumExp = 1;
|
||||
} else {
|
||||
dExp = 1.0 / (d * (logMax-logMin));
|
||||
cumExp = (Math.log(d) - logMin) / (logMax - logMin);
|
||||
}
|
||||
|
||||
assertTrue(dExp >= 0);
|
||||
assertTrue(cumExp >= 0);
|
||||
assertTrue(cumExp <= 1.0);
|
||||
assertEquals(dExp, density, 1e-5);
|
||||
assertEquals(cumExp, cumulative, 1e-5);
|
||||
}
|
||||
|
||||
rd.reseedRandomGenerator(12345);
|
||||
for( int i=0; i<100; i++ ){
|
||||
double d = rd.sample();
|
||||
assertTrue(d >= min);
|
||||
assertTrue(d <= max);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
||||
public class TestCrossoverOperator extends CrossoverOperator {
|
||||
|
||||
private final CrossoverResult[] results;
|
||||
private int resultIdx = 0;
|
||||
|
||||
public PopulationModel getPopulationModel() {
|
||||
return populationModel;
|
||||
}
|
||||
|
||||
public TestCrossoverOperator(CrossoverResult[] results) {
|
||||
this.results = results;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CrossoverResult crossover() {
|
||||
return results[resultIdx++];
|
||||
}
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
|
||||
|
||||
public class TestMutationOperator implements MutationOperator {
|
||||
|
||||
private final boolean[] results;
|
||||
private int resultIdx = 0;
|
||||
|
||||
public TestMutationOperator(boolean[] results) {
|
||||
this.results = results;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean mutate(double[] genes) {
|
||||
return results[resultIdx++];
|
||||
}
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class TestParentSelection extends TwoParentSelection {
|
||||
|
||||
public boolean hasBeenInitialized;
|
||||
|
||||
private final double[][] parents;
|
||||
|
||||
public TestParentSelection(double[][] parents) {
|
||||
this.parents = parents;
|
||||
}
|
||||
|
||||
public TestParentSelection() {
|
||||
this(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initializeInstance(List<Chromosome> population) {
|
||||
super.initializeInstance(population);
|
||||
hasBeenInitialized = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[][] selectParents() {
|
||||
return parents;
|
||||
}
|
||||
|
||||
public List<Chromosome> getPopulation() {
|
||||
return population;
|
||||
}
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TestPopulationInitializer implements PopulationInitializer {
|
||||
@Override
|
||||
public List<Chromosome> getInitializedPopulation(int size) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
|
||||
public class TestRandomGenerator implements RandomGenerator {
|
||||
private final int[] intRandomNumbers;
|
||||
private int currentIntIdx = 0;
|
||||
private final double[] doubleRandomNumbers;
|
||||
private int currentDoubleIdx = 0;
|
||||
|
||||
|
||||
public TestRandomGenerator(int[] intRandomNumbers, double[] doubleRandomNumbers) {
|
||||
this.intRandomNumbers = intRandomNumbers;
|
||||
this.doubleRandomNumbers = doubleRandomNumbers;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(int i) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(int[] ints) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSeed(long l) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void nextBytes(byte[] bytes) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
return intRandomNumbers[currentIntIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextInt(int i) {
|
||||
return intRandomNumbers[currentIntIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong() {
|
||||
throw new NotImplementedException("Not implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nextBoolean() {
|
||||
throw new NotImplementedException("Not implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float nextFloat() {
|
||||
throw new NotImplementedException("Not implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public double nextDouble() {
|
||||
return doubleRandomNumbers[currentDoubleIdx++];
|
||||
}
|
||||
|
||||
@Override
|
||||
public double nextGaussian() {
|
||||
throw new NotImplementedException("Not implemented");
|
||||
}
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class ArithmeticCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
|
||||
double[][] parents = new double[2][];
|
||||
parents[0] = new double[] {1.0};
|
||||
parents[1] = new double[] {2.0};
|
||||
|
||||
TestParentSelection parentSelection = new TestParentSelection(parents);
|
||||
|
||||
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
|
||||
|
||||
ArithmeticCrossover sut =
|
||||
new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build();
|
||||
CrossoverResult result = sut.crossover();
|
||||
|
||||
Assert.assertFalse(result.isModified());
|
||||
Assert.assertEquals(1, result.getGenes().length);
|
||||
Assert.assertEquals(1.0, result.getGenes()[0], 0.001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ArithmeticCrossover_Crossover_WithinCrossoverRate_ShouldReturnLinearCombination() {
|
||||
double[][] parents = new double[2][];
|
||||
parents[0] = new double[] {1.0};
|
||||
parents[1] = new double[] {2.0};
|
||||
|
||||
TestParentSelection parentSelection = new TestParentSelection(parents);
|
||||
|
||||
RandomGenerator rng = new TestRandomGenerator(null, new double[] {0.1, 0.1});
|
||||
|
||||
ArithmeticCrossover sut =
|
||||
new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build();
|
||||
CrossoverResult result = sut.crossover();
|
||||
|
||||
Assert.assertTrue(result.isModified());
|
||||
Assert.assertEquals(1, result.getGenes().length);
|
||||
Assert.assertEquals(1.9, result.getGenes()[0], 0.001);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class CrossoverOperatorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
|
||||
TestCrossoverOperator sut = new TestCrossoverOperator(null);
|
||||
|
||||
PopulationInitializer populationInitializer = new TestPopulationInitializer();
|
||||
|
||||
PopulationModel populationModel =
|
||||
new PopulationModel.Builder().populationInitializer(populationInitializer).build();
|
||||
sut.initializeInstance(populationModel);
|
||||
|
||||
Assert.assertSame(populationModel, sut.getPopulationModel());
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* * Copyright (c) 2021 Deeplearning4j Contributors
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Deque;
|
||||
|
||||
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
|
||||
RandomGenerator rng = new TestRandomGenerator(new int[] {0}, null);
|
||||
CrossoverPointsGenerator sut = new CrossoverPointsGenerator(10, 2, 2, rng);
|
||||
|
||||
Deque<Integer> result = sut.getCrossoverPoints();
|
||||
|
||||
Assert.assertEquals(3, result.size());
|
||||
int a = result.pop();
|
||||
int b = result.pop();
|
||||
int c = result.pop();
|
||||
Assert.assertTrue(a < b);
|
||||
Assert.assertTrue(b < c);
|
||||
Assert.assertEquals(Integer.MAX_VALUE, c);
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue