Update copyrights remove attic and relocate elsewhere

master
agibsonccc 2021-02-09 13:16:31 +09:00
parent 5bd386a4f9
commit 46dbd0b203
1164 changed files with 7970 additions and 116171 deletions

View File

@ -1,45 +0,0 @@
# Arbiter
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
## Modules
Arbiter contains the following modules:
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
## Hyperparameter Optimization Functionality
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
- Grid search
- Random search
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate models configurations of some type ```C```. The following implementations are defined in arbiter-core:
- ```RandomSearchCandidateGenerator```
- ```GridSearchCandidateGenerator```
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
- A key concept here is that they score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
- Note that ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
### Optimization of DeepLearning4J Models
(This section: forthcoming)

View File

@ -1,89 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>arbiter-core</artifactId>
<name>arbiter-core</name>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
</dependency>
<!-- ND4J Shaded Jackson Dependency -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>jackson</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>guava</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>${commons-codec.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,95 +0,0 @@
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<assembly>
<id>bin</id>
<!-- START SNIPPET: formats -->
<formats>
<format>tar.gz</format>
<!--
<format>tar.bz2</format>
<format>zip</format>
-->
</formats>
<!-- END SNIPPET: formats -->
<dependencySets>
<dependencySet>
<outputDirectory>lib</outputDirectory>
<includes>
<include>*:jar:*</include>
</includes>
<excludes>
<exclude>*:sources</exclude>
</excludes>
</dependencySet>
</dependencySets>
<!-- START SNIPPET: fileSets -->
<fileSets>
<fileSet>
<includes>
<include>readme.txt</include>
</includes>
</fileSet>
<fileSet>
<directory>src/main/resources/bin/</directory>
<outputDirectory>bin</outputDirectory>
<includes>
<include>arbiter</include>
</includes>
<lineEnding>unix</lineEnding>
<fileMode>0755</fileMode>
</fileSet>
<fileSet>
<directory>examples</directory>
<outputDirectory>examples</outputDirectory>
<!--
<lineEnding>unix</lineEnding>
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
-->
</fileSet>
<!--
<fileSet>
<directory>src/bin</directory>
<outputDirectory>bin</outputDirectory>
<includes>
<include>hello</include>
</includes>
<lineEnding>unix</lineEnding>
<fileMode>0755</fileMode>
</fileSet>
-->
<fileSet>
<directory>target</directory>
<outputDirectory>./</outputDirectory>
<includes>
<include>*.jar</include>
</includes>
</fileSet>
</fileSets>
<!-- END SNIPPET: fileSets -->
</assembly>

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Created by Alex on 23/07/2017.
*/
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
Map<String, ParameterSpace> m = new LinkedHashMap<>();
//Need to manually build and walk the class hierarchy...
Class<?> currClass = this.getClass();
List<Class<?>> classHierarchy = new ArrayList<>();
while (currClass != Object.class) {
classHierarchy.add(currClass);
currClass = currClass.getSuperclass();
}
for (int i = classHierarchy.size() - 1; i >= 0; i--) {
//Use reflection here to avoid a mass of boilerplate code...
Field[] allFields = classHierarchy.get(i).getDeclaredFields();
for (Field f : allFields) {
String name = f.getName();
Class<?> fieldClass = f.getType();
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
if (!isParamSpacefield) {
continue;
}
f.setAccessible(true);
ParameterSpace<?> p;
try {
p = (ParameterSpace<?>) f.get(this);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
if (p != null) {
m.put(name, p);
}
}
}
return m;
}
}

View File

@ -1,59 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
import org.nd4j.common.function.Supplier;
import java.io.Serializable;
import java.util.Map;
/**
* Candidate: a proposed hyperparameter configuration.
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
*/
@Data
@AllArgsConstructor
public class Candidate<C> implements Serializable {
private Supplier<C> supplier;
private int index;
private double[] flatParameters;
private Map<String, Object> dataParameters;
private Exception exception;
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
}
public Candidate(C value, int index, double[] flatParameters) {
this(new SerializedSupplier<C>(value), index, flatParameters);
}
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
this(value, index, flatParameters, null, null);
}
public C getValue(){
return supplier.get();
}
}

View File

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
/**
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
* This abstraction allows for different ways of generating the next configuration to test; for example,
* random search, grid search, Bayesian optimization methods, etc.
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface CandidateGenerator {
/**
* Is this candidate generator able to generate more candidates? This will always return true in some
* cases, but some search strategies have a limit (grid search, for example)
*/
boolean hasMoreCandidates();
/**
* Generate a candidate hyperparameter configuration
*/
Candidate getCandidate();
/**
* Report results for the candidate generator.
*
* @param result The results to report
*/
void reportResults(OptimizationResult result);
/**
* @return Get the parameter space for this candidate generator
*/
ParameterSpace<?> getParameterSpace();
/**
* @param rngSeed Set the random number generator seed for the candidate generator
*/
void setRngSeed(long rngSeed);
/**
* @return The type (class) of the generated candidates
*/
Class<?> getCandidateType();
}

View File

@ -1,62 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import lombok.Data;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
/**
* An optimization result represents the results of an optimization run, including the candidate configuration, the
* trained model, the score for that model, and index of the model
*
* @author Alex Black
*/
@Data
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonIgnoreProperties({"resultReference"})
public class OptimizationResult implements Serializable {
@JsonProperty
private Candidate candidate;
@JsonProperty
private Double score;
@JsonProperty
private int index;
@JsonProperty
private Object modelSpecificResults;
@JsonProperty
private CandidateInfo candidateInfo;
private ResultReference resultReference;
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
CandidateInfo candidateInfo, ResultReference resultReference) {
this.candidate = candidate;
this.score = score;
this.index = index;
this.modelSpecificResults = modelSpecificResults;
this.candidateInfo = candidateInfo;
this.resultReference = resultReference;
}
}

View File

@ -1,83 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.List;
import java.util.Map;
/**
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
* multiple nested ParameterSpaces
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ParameterSpace<P> {
/**
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
* mapping function (such as the prior probability distribution)
*
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
*/
P getValue(double[] parameterValues);
/**
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
*
* @return Number of hyperparameters to be optimized
*/
int numParameters();
/**
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
* nested parameter spaces
*/
List<ParameterSpace> collectLeaves();
/**
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
* nested parameter spaces. The map should be empty for leaf parameter spaces
*
* @return A map of nested parameter spaces
*/
Map<String, ParameterSpace> getNestedSpaces();
/**
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
*/
@JsonIgnore
boolean isLeaf();
/**
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
*
* @param indices Indices to set. Length should equal {@link #numParameters()}
*/
void setIndices(int... indices);
}

View File

@ -1,64 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.Callable;
/**
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
* that can be executed as a Callable
*
* @author Alex Black
*/
public interface TaskCreator {
/**
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
*
* @param candidate Candidate (model) configuration to be trained
* @param dataProvider DataProvider, for the data
* @param scoreFunction Score function to be used to evaluate the model
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
@Deprecated
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
List<StatusListener> statusListeners, IOptimizationRunner runner);
/**
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
*
* @param candidate Candidate (model) configuration to be trained
* @param dataSource Data source
* @param dataSourceProperties Properties (may be null) for the data source
* @param scoreFunction Score function to be used to evaluate the model
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
* @return A callable that returns an OptimizationResult, once optimization is complete
*/
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
}

View File

@ -1,45 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api;
import java.util.HashMap;
import java.util.Map;
public class TaskCreatorProvider {
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
try {
if(c == null){
return null;
}
return c.newInstance();
} catch (Exception e){
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
}
}
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
Class<? extends TaskCreator> creatorClass){
map.put(spaceClass, creatorClass);
}
}

View File

@ -1,84 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.adapter;
import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
*
* @param <F> Type to convert from
* @param <T> Type to convert to
* @author Alex Black
*/
@AllArgsConstructor
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
protected abstract T convertValue(F from);
protected abstract ParameterSpace<F> underlying();
protected abstract String underlyingName();
@Override
public T getValue(double[] parameterValues) {
return convertValue(underlying().getValue(parameterValues));
}
@Override
public int numParameters() {
return underlying().numParameters();
}
@Override
public List<ParameterSpace> collectLeaves() {
ParameterSpace p = underlying();
if(p.isLeaf()){
return Collections.singletonList(p);
}
return underlying().collectLeaves();
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
}
@Override
public boolean isLeaf() {
return false; //Underlying may be a leaf, however
}
@Override
public void setIndices(int... indices) {
underlying().setIndices(indices);
}
@Override
public String toString() {
return underlying().toString();
}
}

View File

@ -1,56 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.data;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
import java.util.Map;
/**
* DataProvider interface abstracts out the providing of data
* @deprecated Use {@link DataSource}
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Deprecated
public interface DataProvider extends Serializable {
/**
* Get training data given some parameters for the data.
* Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
Object trainData(Map<String, Object> dataParameters);
/**
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
Object testData(Map<String, Object> dataParameters);
Class<?> getDataType();
}

View File

@ -1,91 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.data;
import lombok.Data;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import java.util.Map;
/**
* This is a {@link DataProvider} for
* an {@link DataSetIteratorFactory} which
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
* for use with arbiter.
*
* This {@link DataProvider} is mainly meant for use for command line driven
* applications.
*
* @author Adam Gibson
*/
@Data
public class DataSetIteratorFactoryProvider implements DataProvider {
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
/**
* Get training data given some parameters for the data.
* Data parameters map is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
@Override
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
return create(dataParameters);
}
/**
* Get training data given some parameters for the data. Data parameters map
* is used to specify things like batch
* size data preprocessing
*
* @param dataParameters Parameters for data. May be null or empty for default data
* @return training data
*/
@Override
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
return create(dataParameters);
}
@Override
public Class<?> getDataType() {
return DataSetIteratorFactory.class;
}
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
if (dataParameters == null)
throw new IllegalArgumentException(
"Data parameters is null. Please specify a class name to create a dataset iterator.");
if (!dataParameters.containsKey(FACTORY_KEY))
throw new IllegalArgumentException(
"No data set iterator factory class found. Please specify a class name with key "
+ FACTORY_KEY);
String value = dataParameters.get(FACTORY_KEY).toString();
try {
Class<? extends DataSetIteratorFactory> clazz =
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
return clazz.newInstance();
} catch (Exception e) {
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
}
}
}

View File

@ -1,59 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.data;
import java.io.Serializable;
import java.util.Properties;
/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument contsructor
*
* @author Alex Black
*/
public interface DataSource extends Serializable {
/**
* Configure the current data source with the specified properties
* Note: These properties are fixed for the training instance, and are optionally provided by the user
* at the configuration stage.
* The properties could be anything - and are usually specific to each DataSource implementation.
* For example, values such as batch size could be set using these properties
* @param properties Properties to apply to the data source instance
*/
void configure(Properties properties);
/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object trainData();
/**
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
*/
Object testData();
/**
* The type of data returned by {@link #trainData()} and {@link #testData()}.
* Usually DataSetIterator or MultiDataSetIterator
* @return Class of the objects returned by trainData and testData
*/
Class<?> getDataType();
}

View File

@ -1,42 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.evaluation;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import java.io.Serializable;
import java.util.List;
/**
* ModelEvaluator: Used to conduct additional evaluation.
* For example, this may be classification performance on a test set or similar
*/
public interface ModelEvaluator extends Serializable {
Object evaluateModel(Object model, DataProvider dataProvider);
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
/**
* @return The datatypes supported by this class
*/
List<Class<?>> getSupportedDataTypes();
}

View File

@ -1,65 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.saving;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
/**
* A simple class to store optimization results in-memory.
* Not recommended for large (or a large number of) models.
*/
@NoArgsConstructor
public class InMemoryResultSaver implements ResultSaver {
@Override
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
return new InMemoryResult(result, modelResult);
}
@Override
public List<Class<?>> getSupportedCandidateTypes() {
return Collections.<Class<?>>singletonList(Object.class);
}
@Override
public List<Class<?>> getSupportedModelTypes() {
return Collections.<Class<?>>singletonList(Object.class);
}
@AllArgsConstructor
private static class InMemoryResult implements ResultReference {
private OptimizationResult result;
private Object modelResult;
@Override
public OptimizationResult getResult() throws IOException {
return result;
}
@Override
public Object getResultModel() throws IOException {
return modelResult;
}
}
}

View File

@ -1,39 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.saving;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.IOException;
/**
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
* parameters each)
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
* and we can easily load it back into memory (if/when required) using the getResult() method
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ResultReference {
OptimizationResult getResult() throws IOException;
Object getResultModel() throws IOException;
}

View File

@ -1,59 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.saving;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.IOException;
import java.util.List;
/**
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
* regardless of where/how they are saved.
*
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ResultSaver {
/**
* Save the model (including configuration and any additional evaluation/results)
*
* @param result Optimization result for the model to save
* @param modelResult Model result to save
* @return ResultReference, such that the result can be loaded back into memory
* @throws IOException If IO error occurs during model saving
*/
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
/**
* @return The candidate types supported by this class
*/
List<Class<?>> getSupportedCandidateTypes();
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
}

View File

@ -1,77 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.score;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Properties;
/**
* ScoreFunction defines the objective of hyperparameter optimization.
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
* in the configuration.
*
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ScoreFunction extends Serializable {
/**
* Calculate and return the score, for the given model and data provider
*
* @param model Model to score
* @param dataProvider Data provider - data to use
* @param dataParameters Parameters for data
* @return Calculated score
*/
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
/**
* Calculate and return the score, for the given model and data provider
*
* @param model Model to score
* @param dataSource Data source
* @param dataSourceProperties data source properties
* @return Calculated score
*/
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
/**
* Should this score function be minimized or maximized?
*
* @return true if score should be minimized, false if score should be maximized
*/
boolean minimize();
/**
* @return The model types supported by this class
*/
List<Class<?>> getSupportedModelTypes();
/**
* @return The data types supported by this class
*/
List<Class<?>> getSupportedDataTypes();
}

View File

@ -1,52 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.termination;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
* Note that this is counted as number of completed candidates, plus number of failed candidates.
*/
@AllArgsConstructor
@NoArgsConstructor
@Data
public class MaxCandidatesCondition implements TerminationCondition {
@JsonProperty
private int maxCandidates;
@Override
public void initialize(IOptimizationRunner optimizationRunner) {
//No op
}
@Override
public boolean terminate(IOptimizationRunner optimizationRunner) {
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
}
@Override
public String toString() {
return "MaxCandidatesCondition(" + maxCandidates + ")";
}
}

View File

@ -1,83 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.termination;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.concurrent.TimeUnit;
/**
* Terminate hyperparameter optimization after
* a fixed amount of time has passed
* @author Alex Black
*/
@NoArgsConstructor
@Data
public class MaxTimeCondition implements TerminationCondition {
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
private long duration;
private TimeUnit timeUnit;
private long startTime;
private long endTime;
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
this.duration = duration;
this.timeUnit = timeUnit;
this.startTime = startTime;
this.endTime = endTime;
}
/**
* @param duration Duration of time
* @param timeUnit Unit that the duration is specified in
*/
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
this.duration = duration;
this.timeUnit = timeUnit;
}
@Override
public void initialize(IOptimizationRunner optimizationRunner) {
startTime = System.currentTimeMillis();
this.endTime = startTime + timeUnit.toMillis(duration);
}
@Override
public boolean terminate(IOptimizationRunner optimizationRunner) {
return System.currentTimeMillis() >= endTime;
}
@Override
public String toString() {
if (startTime > 0) {
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
+ "\",end=\"" + formatter.print(endTime) + "\")";
} else {
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
}
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.api.termination;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
/**
* Global termination condition for conducting hyperparameter optimization.
* Termination conditions are used to determine if/when the optimization should stop.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonInclude(JsonInclude.Include.NON_NULL)
public interface TerminationCondition {
/**
* Initialize the termination condition (such as starting timers, etc).
*/
void initialize(IOptimizationRunner optimizationRunner);
/**
* Determine whether optimization should be terminated
*
* @param optimizationRunner Optimization runner
* @return true if learning should be terminated, false otherwise
*/
boolean terminate(IOptimizationRunner optimizationRunner);
}

View File

@ -1,223 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.config;
import lombok.*;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
/**
* OptimizationConfiguration ties together all of the various
* components (such as data, score functions, result saving etc)
* required to execute hyperparameter optimization.
*
* @author Alex Black
*/
@Data
@NoArgsConstructor
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class OptimizationConfiguration {
@JsonSerialize
private DataProvider dataProvider;
@JsonSerialize
private Class<? extends DataSource> dataSource;
@JsonSerialize
private Properties dataSourceProperties;
@JsonSerialize
private CandidateGenerator candidateGenerator;
@JsonSerialize
private ResultSaver resultSaver;
@JsonSerialize
private ScoreFunction scoreFunction;
@JsonSerialize
private List<TerminationCondition> terminationConditions;
@JsonSerialize
private Long rngSeed;
@Getter
@Setter
private long executionStartTime;
private OptimizationConfiguration(Builder builder) {
this.dataProvider = builder.dataProvider;
this.dataSource = builder.dataSource;
this.dataSourceProperties = builder.dataSourceProperties;
this.candidateGenerator = builder.candidateGenerator;
this.resultSaver = builder.resultSaver;
this.scoreFunction = builder.scoreFunction;
this.terminationConditions = builder.terminationConditions;
this.rngSeed = builder.rngSeed;
if (rngSeed != null)
candidateGenerator.setRngSeed(rngSeed);
//Validate the configuration: data types, score types, etc
//TODO
//Validate that the dataSource has a no-arg constructor
if(dataSource != null){
try{
dataSource.getConstructor();
} catch (NoSuchMethodException e){
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
}
}
}
public static class Builder {
private DataProvider dataProvider;
private Class<? extends DataSource> dataSource;
private Properties dataSourceProperties;
private CandidateGenerator candidateGenerator;
private ResultSaver resultSaver;
private ScoreFunction scoreFunction;
private List<TerminationCondition> terminationConditions;
private Long rngSeed;
/**
* @deprecated Use {@link #dataSource(Class, Properties)}
*/
@Deprecated
public Builder dataProvider(DataProvider dataProvider) {
this.dataProvider = dataProvider;
return this;
}
/**
* DataSource: defines where the data should come from for training and testing.
* Note that implementations must have a no-argument constructor
* @param dataSource Class for the data source
* @param dataSourceProperties May be null. Properties for configuring the data source
*/
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties){
this.dataSource = dataSource;
this.dataSourceProperties = dataSourceProperties;
return this;
}
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
this.candidateGenerator = candidateGenerator;
return this;
}
public Builder modelSaver(ResultSaver resultSaver) {
this.resultSaver = resultSaver;
return this;
}
public Builder scoreFunction(ScoreFunction scoreFunction) {
this.scoreFunction = scoreFunction;
return this;
}
/**
* Termination conditions to use
* @param conditions
* @return
*/
public Builder terminationConditions(TerminationCondition... conditions) {
terminationConditions = Arrays.asList(conditions);
return this;
}
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
this.terminationConditions = terminationConditions;
return this;
}
public Builder rngSeed(long rngSeed) {
this.rngSeed = rngSeed;
return this;
}
public OptimizationConfiguration build() {
return new OptimizationConfiguration(this);
}
}
/**
* Create an optimization configuration from the json
* @param json the json to create the config from
* For type definitions
* @see OptimizationConfiguration
*/
public static OptimizationConfiguration fromYaml(String json) {
try {
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Create an optimization configuration from the json
* @param json the json to create the config from
* @see OptimizationConfiguration
*/
public static OptimizationConfiguration fromJson(String json) {
try {
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Return a json configuration of this optimization configuration
*
* @return
*/
public String toJson() {
try {
return JsonMapper.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Return a yaml configuration of this optimization configuration
*
* @return
*/
public String toYaml() {
try {
return JsonMapper.getYamlMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,99 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import java.util.Arrays;
/**
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
*/
public class DegenerateIntegerDistribution implements IntegerDistribution {
private int value;
public DegenerateIntegerDistribution(int value) {
this.value = value;
}
@Override
public double probability(int x) {
return (x == value ? 1.0 : 0.0);
}
@Override
public double cumulativeProbability(int x) {
return (x >= value ? 1.0 : 0.0);
}
@Override
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
}
@Override
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
throw new UnsupportedOperationException();
}
@Override
public double getNumericalMean() {
return value;
}
@Override
public double getNumericalVariance() {
return 0;
}
@Override
public int getSupportLowerBound() {
return value;
}
@Override
public int getSupportUpperBound() {
return value;
}
@Override
public boolean isSupportConnected() {
return true;
}
@Override
public void reseedRandomGenerator(long seed) {
//no op
}
@Override
public int sample() {
return value;
}
@Override
public int[] sample(int sampleSize) {
int[] out = new int[sampleSize];
Arrays.fill(out, value);
return out;
}
}

View File

@ -1,151 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.*;
/**
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
* don't implement serializable etc.
* Which makes unit testing etc quite difficult.
*
* @author Alex Black
*/
public class DistributionUtils {
private DistributionUtils() {}
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
if (a.getClass() != b.getClass())
return false;
Class<?> c = a.getClass();
if (c == BetaDistribution.class) {
BetaDistribution ba = (BetaDistribution) a;
BetaDistribution bb = (BetaDistribution) b;
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
} else if (c == CauchyDistribution.class) {
CauchyDistribution ca = (CauchyDistribution) a;
CauchyDistribution cb = (CauchyDistribution) b;
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
} else if (c == ChiSquaredDistribution.class) {
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
} else if (c == ExponentialDistribution.class) {
ExponentialDistribution ea = (ExponentialDistribution) a;
ExponentialDistribution eb = (ExponentialDistribution) b;
return ea.getMean() == eb.getMean();
} else if (c == FDistribution.class) {
FDistribution fa = (FDistribution) a;
FDistribution fb = (FDistribution) b;
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
} else if (c == GammaDistribution.class) {
GammaDistribution ga = (GammaDistribution) a;
GammaDistribution gb = (GammaDistribution) b;
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
} else if (c == LevyDistribution.class) {
LevyDistribution la = (LevyDistribution) a;
LevyDistribution lb = (LevyDistribution) b;
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
} else if (c == LogNormalDistribution.class) {
LogNormalDistribution la = (LogNormalDistribution) a;
LogNormalDistribution lb = (LogNormalDistribution) b;
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
} else if (c == NormalDistribution.class) {
NormalDistribution na = (NormalDistribution) a;
NormalDistribution nb = (NormalDistribution) b;
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
} else if (c == ParetoDistribution.class) {
ParetoDistribution pa = (ParetoDistribution) a;
ParetoDistribution pb = (ParetoDistribution) b;
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
} else if (c == TDistribution.class) {
TDistribution ta = (TDistribution) a;
TDistribution tb = (TDistribution) b;
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
} else if (c == TriangularDistribution.class) {
TriangularDistribution ta = (TriangularDistribution) a;
TriangularDistribution tb = (TriangularDistribution) b;
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
} else if (c == UniformRealDistribution.class) {
UniformRealDistribution ua = (UniformRealDistribution) a;
UniformRealDistribution ub = (UniformRealDistribution) b;
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
} else if (c == WeibullDistribution.class) {
WeibullDistribution wa = (WeibullDistribution) a;
WeibullDistribution wb = (WeibullDistribution) b;
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
} else if (c == LogUniformDistribution.class ){
LogUniformDistribution lu_a = (LogUniformDistribution)a;
LogUniformDistribution lu_b = (LogUniformDistribution)b;
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
} else {
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
}
}
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
if (a.getClass() != b.getClass())
return false;
Class<?> c = a.getClass();
if (c == BinomialDistribution.class) {
BinomialDistribution ba = (BinomialDistribution) a;
BinomialDistribution bb = (BinomialDistribution) b;
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
} else if (c == GeometricDistribution.class) {
GeometricDistribution ga = (GeometricDistribution) a;
GeometricDistribution gb = (GeometricDistribution) b;
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
} else if (c == HypergeometricDistribution.class) {
HypergeometricDistribution ha = (HypergeometricDistribution) a;
HypergeometricDistribution hb = (HypergeometricDistribution) b;
return ha.getPopulationSize() == hb.getPopulationSize()
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
&& ha.getSampleSize() == hb.getSampleSize();
} else if (c == PascalDistribution.class) {
PascalDistribution pa = (PascalDistribution) a;
PascalDistribution pb = (PascalDistribution) b;
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
} else if (c == PoissonDistribution.class) {
PoissonDistribution pa = (PoissonDistribution) a;
PoissonDistribution pb = (PoissonDistribution) b;
return pa.getMean() == pb.getMean();
} else if (c == UniformIntegerDistribution.class) {
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
} else if (c == ZipfDistribution.class) {
ZipfDistribution za = (ZipfDistribution) a;
ZipfDistribution zb = (ZipfDistribution) b;
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
} else {
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
}
}
}

View File

@ -1,157 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import java.util.Random;
/**
* Log uniform distribution, with support in range [min, max] for min > 0
*
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
*
* @author Alex Black
*/
public class LogUniformDistribution implements RealDistribution {
@Getter private final double min;
@Getter private final double max;
private final double logMin;
private final double logMax;
private transient Random rng = new Random();
/**
*
* @param min Minimum value
* @param max Maximum value
*/
public LogUniformDistribution(double min, double max) {
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
+ min + "," + max + ")");
this.min = min;
this.max = max;
this.logMin = Math.log(min);
this.logMax = Math.log(max);
}
@Override
public double probability(double x) {
if(x < min || x > max){
return 0;
}
return 1.0 / (x * (logMax - logMin));
}
@Override
public double density(double x) {
return probability(x);
}
@Override
public double cumulativeProbability(double x) {
if(x <= min){
return 0.0;
} else if(x >= max){
return 1.0;
}
return (Math.log(x)-logMin)/(logMax-logMin);
}
@Override
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
return cumulativeProbability(x1) - cumulativeProbability(x0);
}
@Override
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
return Math.exp(p * (logMax-logMin) + logMin);
}
@Override
public double getNumericalMean() {
return (max-min)/(logMax-logMin);
}
@Override
public double getNumericalVariance() {
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
return d1 / (2*Math.pow(logMax-logMin, 2.0));
}
@Override
public double getSupportLowerBound() {
return min;
}
@Override
public double getSupportUpperBound() {
return max;
}
@Override
public boolean isSupportLowerBoundInclusive() {
return true;
}
@Override
public boolean isSupportUpperBoundInclusive() {
return true;
}
@Override
public boolean isSupportConnected() {
return true;
}
@Override
public void reseedRandomGenerator(long seed) {
rng.setSeed(seed);
}
@Override
public double sample() {
return inverseCumulativeProbability(rng.nextDouble());
}
@Override
public double[] sample(int sampleSize) {
double[] d = new double[sampleSize];
for( int i=0; i<sampleSize; i++ ){
d[i] = sample();
}
return d;
}
@Override
public String toString(){
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
}
}

View File

@ -1,93 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
* are built.
*
* @param <T> Type of candidates to generate
*/
@Data
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
protected ParameterSpace<T> parameterSpace;
protected AtomicInteger candidateCounter = new AtomicInteger(0);
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
protected Map<String, Object> dataParameters;
protected boolean initDone = false;
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
boolean initDone) {
this.parameterSpace = parameterSpace;
this.dataParameters = dataParameters;
this.initDone = initDone;
}
protected void initialize() {
if(!initDone) {
//First: collect leaf parameter spaces objects and remove duplicates
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
//Second: assign each a number
int i = 0;
for (ParameterSpace ps : noDuplicatesList) {
int np = ps.numParameters();
if (np == 1) {
ps.setIndices(i++);
} else {
int[] values = new int[np];
for (int j = 0; j < np; j++)
values[j] = i++;
ps.setIndices(values);
}
}
initDone = true;
}
}
@Override
public ParameterSpace<T> getParameterSpace() {
return parameterSpace;
}
@Override
public void reportResults(OptimizationResult result) {
//No op
}
@Override
public void setRngSeed(long rngSeed) {
rng.setSeed(rngSeed);
}
}

View File

@ -1,189 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
import java.util.Map;
/**
* Uses a genetic algorithm to generate candidates.
*
* @author Alexandre Boulanger
*/
@Slf4j
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
@Getter
protected final PopulationModel populationModel;
protected final ChromosomeFactory chromosomeFactory;
protected final SelectionOperator selectionOperator;
protected boolean hasMoreCandidates = true;
public static class Builder {
protected final ParameterSpace<?> parameterSpace;
protected Map<String, Object> dataParameters;
protected boolean initDone;
protected boolean minimizeScore;
protected PopulationModel populationModel;
protected ChromosomeFactory chromosomeFactory;
protected SelectionOperator selectionOperator;
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
*/
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
this.parameterSpace = parameterSpace;
this.minimizeScore = scoreFunction.minimize();
}
/**
* @param populationModel The PopulationModel instance to use.
*/
public Builder populationModel(PopulationModel populationModel) {
this.populationModel = populationModel;
return this;
}
/**
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
*/
public Builder selectionOperator(SelectionOperator selectionOperator) {
this.selectionOperator = selectionOperator;
return this;
}
public Builder dataParameters(Map<String, Object> dataParameters) {
this.dataParameters = dataParameters;
return this;
}
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
this.initDone = initDone;
return this;
}
/**
* @param chromosomeFactory The ChromosomeFactory to use
*/
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
this.chromosomeFactory = chromosomeFactory;
return this;
}
public GeneticSearchCandidateGenerator build() {
if (populationModel == null) {
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
.build();
}
if (chromosomeFactory == null) {
chromosomeFactory = new ChromosomeFactory();
}
if (selectionOperator == null) {
selectionOperator = new GeneticSelectionOperator.Builder().build();
}
return new GeneticSearchCandidateGenerator(this);
}
}
private GeneticSearchCandidateGenerator(Builder builder) {
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
initialize();
chromosomeFactory = builder.chromosomeFactory;
populationModel = builder.populationModel;
selectionOperator = builder.selectionOperator;
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
populationModel.initializeInstance(builder.minimizeScore);
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
}
@Override
public boolean hasMoreCandidates() {
return hasMoreCandidates;
}
@Override
public Candidate getCandidate() {
double[] values = null;
Object value = null;
Exception e = null;
try {
values = selectionOperator.buildNextGenes();
value = parameterSpace.getValue(values);
} catch (GeneticGenerationException e2) {
log.warn("Error generating candidate", e2);
e = e2;
hasMoreCandidates = false;
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
@Override
public String toString() {
return "GeneticSearchCandidateGenerator";
}
@Override
public void reportResults(OptimizationResult result) {
if (result.getScore() == null) {
return;
}
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
result.getScore());
populationModel.add(newChromosome);
}
}

View File

@ -1,234 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.random.RandomAdaptor;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
* Note that:<br>
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
* that hyperparameter<br>
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
*
* @author Alex Black
*/
@Slf4j
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
/**
* In what order should candidates be generated?<br>
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
* will be changed least rapidly.<br>
* <b>RandomOrder</b>: generate candidates in a random order<br>
* In both cases, the same candidates will be generated; only the order of generation is different
*/
public enum Mode {
Sequential, RandomOrder
}
private final int discretizationCount;
private final Mode mode;
private int[] numValuesPerParam;
@Getter
private int totalNumCandidates;
private Queue<Integer> order;
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
* do [0.0, 0.5, 1.0]. Note that if all values
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
* in which candidates should be generated.
*/
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
@JsonProperty("initDone") boolean initDone) {
super(parameterSpace, dataParameters, initDone);
this.discretizationCount = discretizationCount;
this.mode = mode;
initialize();
}
/**
* @param parameterSpace ParameterSpace from which to generate candidates
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
* do [0.0, 0.5, 1.0]. Note that if all values
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
* in which candidates should be generated.
*/
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
Map<String, Object> dataParameters){
this(parameterSpace, discretizationCount, mode, dataParameters, false);
}
@Override
protected void initialize() {
super.initialize();
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
int nParams = leaves.size();
//Work out for each parameter: is it continuous or discrete?
// for grid search: discrete values are grid-searchable as-is
// continuous values: discretize using 'discretizationCount' bins
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
numValuesPerParam = new int[nParams];
long searchSize = 1;
for (int i = 0; i < nParams; i++) {
ParameterSpace ps = leaves.get(i);
if (ps instanceof DiscreteParameterSpace) {
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
numValuesPerParam[i] = dps.numValues();
} else if (ps instanceof IntegerParameterSpace) {
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
int min = ips.getMin();
int max = ips.getMax();
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
} else if (ps instanceof FixedValue){
numValuesPerParam[i] = 1;
} else {
numValuesPerParam[i] = discretizationCount;
}
searchSize *= numValuesPerParam[i];
}
if (searchSize >= Integer.MAX_VALUE)
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
order = new ConcurrentLinkedQueue<>();
totalNumCandidates = (int) searchSize;
switch (mode) {
case Sequential:
for (int i = 0; i < totalNumCandidates; i++) {
order.add(i);
}
break;
case RandomOrder:
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
for (int i = 0; i < totalNumCandidates; i++) {
tempList.add(i);
}
Collections.shuffle(tempList, new RandomAdaptor(rng));
order.addAll(tempList);
break;
default:
throw new RuntimeException();
}
}
@Override
public boolean hasMoreCandidates() {
return !order.isEmpty();
}
@Override
public Candidate getCandidate() {
int next = order.remove();
//Next: max integer (candidate number) to values
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
Object value = null;
Exception e = null;
try {
value = parameterSpace.getValue(values);
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
//How? first map to index of num possible values. Then: to double values in range 0 to 1
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
//Based on: Nd4j Shape.ind2sub
int countNon1 = 0;
for( int i : numValuesPerParam)
if(i > 1)
countNon1++;
int denom = product;
int num = candidateIdx;
int[] index = new int[numValuesPerParam.length];
for (int i = index.length - 1; i >= 0; i--) {
denom /= numValuesPerParam[i];
index[i] = num / denom;
num %= denom;
}
//Now: convert indexes to values in range [0,1]
//min value -> 0
//max value -> 1
double[] out = new double[countNon1];
int outIdx = 0;
for (int i = 0; i < numValuesPerParam.length; i++) {
if (numValuesPerParam[i] > 1){
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
}
}
return out;
}
@Override
public String toString() {
return "GridSearchCandidateGenerator(mode=" + mode + ")";
}
}

View File

@ -1,95 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Map;
/**
* RandomSearchGenerator: generates candidates at random.<br>
* Note: if a probability distribution is provided for continuous hyperparameters,
* this will be taken into account
* when generating candidates. This allows the search to be weighted more towards
* certain values according to a probability
* density. For example: generate samples for learning rate according to log uniform distribution
*
* @author Alex Black
*/
@Slf4j
@EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
public class RandomSearchGenerator extends BaseCandidateGenerator {
@JsonCreator
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
@JsonProperty("initDone") boolean initDone) {
super(parameterSpace, dataParameters, initDone);
initialize();
}
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
this(parameterSpace, dataParameters, false);
}
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
this(parameterSpace, null, false);
}
@Override
public boolean hasMoreCandidates() {
return true;
}
@Override
public Candidate getCandidate() {
double[] randomValues = new double[parameterSpace.numParameters()];
for (int i = 0; i < randomValues.length; i++)
randomValues[i] = rng.nextDouble();
Object value = null;
Exception e = null;
try {
value = parameterSpace.getValue(randomValues);
} catch (Exception e2) {
log.warn("Error getting configuration for candidate", e2);
e = e2;
}
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
}
@Override
public Class<?> getCandidateType() {
return null;
}
@Override
public String toString() {
return "RandomSearchGenerator";
}
}

View File

@ -1,44 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic;
import lombok.Data;
/**
* Candidates are stored as Chromosome in the population model
*
* @author Alexandre Boulanger
*/
@Data
public class Chromosome {
/**
* The fitness score of the genes.
*/
protected final double fitness;
/**
* The genes.
*/
protected final double[] genes;
public Chromosome(double[] genes, double fitness) {
this.genes = genes;
this.fitness = fitness;
}
}

View File

@ -1,53 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic;
/**
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
*
* @author Alexandre Boulanger
*/
public class ChromosomeFactory {
private int chromosomeLength;
/**
* Called by the GeneticSearchCandidateGenerator.
*/
public void initializeInstance(int chromosomeLength) {
this.chromosomeLength = chromosomeLength;
}
/**
* Create a new instance of a Chromosome
*
* @param genes The genes
* @param fitness The fitness score
* @return A new instance of Chromosome
*/
public Chromosome createChromosome(double[] genes, double fitness) {
return new Chromosome(genes, fitness);
}
/**
* @return The number of genes in a chromosome
*/
public int getChromosomeLength() {
return chromosomeLength;
}
}

View File

@ -1,122 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* A crossover operator that linearly combines the genes of two parents. <br>
* When a crossover is generated (with a of probability <i>crossover rate</i>), each genes is a linear combination of the corresponding genes of the parents.
* <p>
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
*
* @author Alexandre Boulanger
*/
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private final double crossoverRate;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public ArithmeticCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new ArithmeticCrossover(this);
}
}
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
double[] offspringValues = new double[parents[0].length];
if (rng.nextDouble() < crossoverRate) {
for (int i = 0; i < offspringValues.length; ++i) {
double t = rng.nextDouble();
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
}
return new CrossoverResult(true, offspringValues);
}
return new CrossoverResult(false, parents[0]);
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* Abstract class for all crossover operators
*
* @author Alexandre Boulanger
*/
public abstract class CrossoverOperator {
protected PopulationModel populationModel;
/**
* Will be called by the selection operator once the population model is instantiated.
*/
public void initializeInstance(PopulationModel populationModel) {
this.populationModel = populationModel;
}
/**
* Performs the crossover
*
* @return The crossover result. See {@link CrossoverResult}.
*/
public abstract CrossoverResult crossover();
}

View File

@ -1,45 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import lombok.Data;
/**
* Returned by a crossover operator
*
* @author Alexandre Boulanger
*/
@Data
public class CrossoverResult {
/**
* If false, there was no crossover and the operator simply returned the genes of a random parent.
* If true, the genes are the result of a crossover.
*/
private final boolean isModified;
/**
* The genes returned by the operator.
*/
private final double[] genes;
public CrossoverResult(boolean isModified, double[] genes) {
this.isModified = isModified;
this.genes = genes;
}
}

View File

@ -1,180 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.nd4j.common.base.Preconditions;
import java.util.Deque;
/**
* The K-Point crossover will select at random multiple crossover points.<br>
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
*/
public class KPointCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private static final int DEFAULT_MIN_CROSSOVER = 1;
private static final int DEFAULT_MAX_CROSSOVER = 4;
private final double crossoverRate;
private final int minCrossovers;
private final int maxCrossovers;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* The number of crossovers points (default is min 1, max 4)
*
* @param min The minimum number
* @param max The maximum number
*/
public Builder numCrossovers(int min, int max) {
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
this.minCrossovers = min;
this.maxCrossovers = max;
return this;
}
/**
* Use a fixed number of crossover points
*
* @param num The number of crossovers
*/
public Builder numCrossovers(int num) {
Preconditions.checkState(num >= 0, "Num must be positive");
this.minCrossovers = num;
this.maxCrossovers = num;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public KPointCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new KPointCrossover(this);
}
}
private CrossoverPointsGenerator crossoverPointsGenerator;
private KPointCrossover(KPointCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.maxCrossovers = builder.maxCrossovers;
this.minCrossovers = builder.minCrossovers;
this.rng = builder.rng;
}
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
if (crossoverPointsGenerator == null) {
crossoverPointsGenerator =
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
}
return crossoverPointsGenerator;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
boolean isModified = false;
double[] resultGenes = parents[0];
if (rng.nextDouble() < crossoverRate) {
// Select crossover points
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
// Crossover
resultGenes = new double[parents[0].length];
int currentParent = 0;
int nextCrossover = crossoverPoints.pop();
for (int i = 0; i < resultGenes.length; ++i) {
if (i == nextCrossover) {
currentParent = currentParent == 0 ? 1 : 0;
nextCrossover = crossoverPoints.pop();
}
resultGenes[i] = parents[currentParent][i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -1,125 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* The single point crossover will select a random point where every genes before that point comes from one parent
* and after which every genes comes from the other parent.
*
* @author Alexandre Boulanger
*/
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private final RandomGenerator rng;
private final double crossoverRate;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public SinglePointCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new SinglePointCrossover(this);
}
}
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
public CrossoverResult crossover() {
double[][] parents = parentSelection.selectParents();
boolean isModified = false;
double[] resultGenes = parents[0];
if (rng.nextDouble() < crossoverRate) {
int chromosomeLength = parents[0].length;
// Crossover
resultGenes = new double[chromosomeLength];
int crossoverPoint = rng.nextInt(chromosomeLength);
for (int i = 0; i < resultGenes.length; ++i) {
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -1,48 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* Abstract class for all crossover operators that applies to two parents.
*
* @author Alexandre Boulanger
*/
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
protected final TwoParentSelection parentSelection;
/**
* @param parentSelection A parent selection that selects two parents.
*/
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
}
/**
* Will be called by the selection operator once the population model is instantiated.
*/
@Override
public void initializeInstance(PopulationModel populationModel) {
super.initializeInstance(populationModel);
parentSelection.initializeInstance(populationModel.getPopulation());
}
}

View File

@ -1,138 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.nd4j.common.base.Preconditions;
/**
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
*
* @author Alexandre Boulanger
*/
public class UniformCrossover extends TwoParentsCrossoverOperator {
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
private final double crossoverRate;
private final double parentBiasFactor;
private final RandomGenerator rng;
public static class Builder {
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
private RandomGenerator rng;
private TwoParentSelection parentSelection;
/**
* The probability that the operator generates a crossover (default 0.85).
*
* @param rate A value between 0.0 and 1.0
*/
public Builder crossoverRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.crossoverRate = rate;
return this;
}
/**
* A factor that will introduce a bias in the parent selection.<br>
*
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
*/
public Builder parentBiasFactor(double factor) {
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
factor);
this.parentBiasFactor = factor;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
/**
* The parent selection behavior. Default is random parent selection.
*
* @param parentSelection An instance of TwoParentSelection
*/
public Builder parentSelection(TwoParentSelection parentSelection) {
this.parentSelection = parentSelection;
return this;
}
public UniformCrossover build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
if (parentSelection == null) {
parentSelection = new RandomTwoParentSelection();
}
return new UniformCrossover(this);
}
}
private UniformCrossover(UniformCrossover.Builder builder) {
super(builder.parentSelection);
this.crossoverRate = builder.crossoverRate;
this.parentBiasFactor = builder.parentBiasFactor;
this.rng = builder.rng;
}
/**
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
* One of the parent may be favored if the bias is different than 0.5
* Otherwise, returns the genes of a random parent.
*
* @return The crossover result. See {@link CrossoverResult}.
*/
@Override
public CrossoverResult crossover() {
// select the parents
double[][] parents = parentSelection.selectParents();
double[] resultGenes = parents[0];
boolean isModified = false;
if (rng.nextDouble() < crossoverRate) {
// Crossover
resultGenes = new double[parents[0].length];
for (int i = 0; i < resultGenes.length; ++i) {
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
}
isModified = true;
}
return new CrossoverResult(isModified, resultGenes);
}
}

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* Abstract class for all parent selection behaviors
*
* @author Alexandre Boulanger
*/
public abstract class ParentSelection {
protected List<Chromosome> population;
/**
* Will be called by the crossover operator once the population model is instantiated.
*/
public void initializeInstance(List<Chromosome> population) {
this.population = population;
}
/**
* Performs the parent selection
*
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
*/
public abstract double[][] selectParents();
}

View File

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
/**
* A parent selection behavior that returns two random parents.
*
* @author Alexandre Boulanger
*/
public class RandomTwoParentSelection extends TwoParentSelection {
private final RandomGenerator rng;
public RandomTwoParentSelection() {
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public RandomTwoParentSelection(RandomGenerator rng) {
this.rng = rng;
}
/**
* Selects two random parents
*
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
*/
@Override
public double[][] selectParents() {
double[][] parents = new double[2][];
int parent1Idx = rng.nextInt(population.size());
int parent2Idx;
do {
parent2Idx = rng.nextInt(population.size());
} while (parent1Idx == parent2Idx);
parents[0] = population.get(parent1Idx).getGenes();
parents[1] = population.get(parent2Idx).getGenes();
return parents;
}
}

View File

@ -1,27 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
/**
* Abstract class for all parent selection behaviors that selects two parents.
*
* @author Alexandre Boulanger
*/
public abstract class TwoParentSelection extends ParentSelection {
}

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import java.util.*;
/**
* A helper class used by {@link KPointCrossover} to generate the crossover points
*
* @author Alexandre Boulanger
*/
public class CrossoverPointsGenerator {
private final int minCrossovers;
private final int maxCrossovers;
private final RandomGenerator rng;
private List<Integer> parameterIndexes;
/**
* Constructor
*
* @param chromosomeLength The number of genes
* @param minCrossovers The minimum number of crossover points to generate
* @param maxCrossovers The maximum number of crossover points to generate
* @param rng A RandomGenerator instance
*/
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
this.minCrossovers = minCrossovers;
this.maxCrossovers = maxCrossovers;
this.rng = rng;
parameterIndexes = new ArrayList<Integer>();
for (int i = 0; i < chromosomeLength; ++i) {
parameterIndexes.add(i);
}
}
/**
* Generate a list of crossover points.
*
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
*/
public Deque<Integer> getCrossoverPoints() {
Collections.shuffle(parameterIndexes);
List<Integer> crossoverPointLists =
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
Collections.sort(crossoverPointLists);
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
crossoverPoints.add(Integer.MAX_VALUE);
return crossoverPoints;
}
}

View File

@ -1,43 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* The cull operator will remove from the population the least desirables chromosomes.
*
* @author Alexandre Boulanger
*/
public interface CullOperator {
/**
* Will be called by the population model once created.
*/
void initializeInstance(PopulationModel populationModel);
/**
* Cull the population to the culled size.
*/
void cullPopulation();
/**
* @return The target population size after culling.
*/
int getCulledSize();
}

View File

@ -1,52 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
/**
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
*
* @author Alexandre Boulanger
*/
public class LeastFitCullOperator extends RatioCullOperator {
/**
* The default cull ratio is 1/3.
*/
public LeastFitCullOperator() {
super();
}
/**
* @param cullRatio The ratio of the maximum population size to be culled.<br>
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
*/
public LeastFitCullOperator(double cullRatio) {
super(cullRatio);
}
/**
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
*/
@Override
public void cullPopulation() {
while (population.size() > culledSize) {
population.remove(population.size() - 1);
}
}
}

View File

@ -1,72 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.nd4j.common.base.Preconditions;
import java.util.List;
/**
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
*
* @author Alexandre Boulanger
*/
public abstract class RatioCullOperator implements CullOperator {
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
protected int culledSize;
protected List<Chromosome> population;
protected final double cullRatio;
/**
* @param cullRatio The ratio of the maximum population size to be culled.<br>
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
*/
public RatioCullOperator(double cullRatio) {
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
cullRatio);
this.cullRatio = cullRatio;
}
/**
* The default cull ratio is 1/3
*/
public RatioCullOperator() {
this(DEFAULT_CULL_RATIO);
}
/**
* Will be called by the population model once created.
*/
public void initializeInstance(PopulationModel populationModel) {
this.population = populationModel.getPopulation();
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
}
/**
* @return The target population size after culling.
*/
@Override
public int getCulledSize() {
return culledSize;
}
}

View File

@ -1,25 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
public class GeneticGenerationException extends RuntimeException {
public GeneticGenerationException(String message) {
super(message);
}
}

View File

@ -1,35 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
/**
* The mutation operator will apply a mutation to the given genes.
*
* @author Alexandre Boulanger
*/
public interface MutationOperator {
/**
* Performs a mutation.
*
* @param genes The genes to be mutated
* @return True if the genes were mutated, otherwise false.
*/
boolean mutate(double[] genes);
}

View File

@ -1,95 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.nd4j.common.base.Preconditions;
/**
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
*
* @author Alexandre Boulanger
*/
public class RandomMutationOperator implements MutationOperator {
private static final double DEFAULT_MUTATION_RATE = 0.005;
private final double mutationRate;
private final RandomGenerator rng;
public static class Builder {
private double mutationRate = DEFAULT_MUTATION_RATE;
private RandomGenerator rng;
/**
* Each gene will have this probability of being mutated.
*
* @param rate The mutation rate. (default 0.005)
*/
public Builder mutationRate(double rate) {
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
this.mutationRate = rate;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
public RandomMutationOperator build() {
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
return new RandomMutationOperator(this);
}
}
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
this.mutationRate = builder.mutationRate;
this.rng = builder.rng;
}
/**
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
*
* @param genes The genes to be mutated
* @return True if the genes were mutated, otherwise false.
*/
@Override
public boolean mutate(double[] genes) {
boolean hasMutation = false;
for (int i = 0; i < genes.length; ++i) {
if (rng.nextDouble() < mutationRate) {
genes[i] = rng.nextDouble();
hasMutation = true;
}
}
return hasMutation;
}
}

View File

@ -1,43 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.ArrayList;
import java.util.List;
/**
* A population initializer that build an empty population.
*
* @author Alexandre Boulanger
*/
public class EmptyPopulationInitializer implements PopulationInitializer {
/**
* Initialize an empty population
*
* @param size The maximum size of the population.
* @return The initialized population.
*/
@Override
public List<Chromosome> getInitializedPopulation(int size) {
return new ArrayList<>(size);
}
}

View File

@ -1,38 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* An initializer that construct the population used by the population model.
*
* @author Alexandre Boulanger
*/
public interface PopulationInitializer {
/**
* Called by the population model to construct the population
*
* @param size The maximum size of the population
* @return An initialized population
*/
List<Chromosome> getInitializedPopulation(int size);
}

View File

@ -1,37 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import java.util.List;
/**
* A listener that is called when the population changes.
*
* @author Alexandre Boulanger
*/
public interface PopulationListener {
/**
* Called after the population has changed.
*
* @param population The population after it has changed.
*/
void onChanged(List<Chromosome> population);
}

View File

@ -1,184 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* The population model handles all aspects of the population (initialization, additions and culling)
*
* @author Alexandre Boulanger
*/
public class PopulationModel {
private static final int DEFAULT_POPULATION_SIZE = 30;
private final CullOperator cullOperator;
private final List<PopulationListener> populationListeners = new ArrayList<>();
private Comparator<Chromosome> chromosomeComparator;
/**
* The maximum population size
*/
@Getter
private final int populationSize;
/**
* The population
*/
@Getter
public final List<Chromosome> population;
/**
* A comparator used when higher fitness value is better
*/
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
@Override
public int compare(Chromosome lhs, Chromosome rhs) {
return -Double.compare(lhs.getFitness(), rhs.getFitness());
}
}
/**
* A comparator used when lower fitness value is better
*/
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
@Override
public int compare(Chromosome lhs, Chromosome rhs) {
return Double.compare(lhs.getFitness(), rhs.getFitness());
}
}
public static class Builder {
private int populationSize = DEFAULT_POPULATION_SIZE;
private PopulationInitializer populationInitializer;
private CullOperator cullOperator;
/**
* Use an alternate population initialization behavior. Default is empty population.
*
* @param populationInitializer An instance of PopulationInitializer
*/
public Builder populationInitializer(PopulationInitializer populationInitializer) {
this.populationInitializer = populationInitializer;
return this;
}
/**
* The maximum population size. <br>
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
*
* @param size The maximum size of the population
*/
public Builder populationSize(int size) {
populationSize = size;
return this;
}
/**
* Use an alternate cull operator behavior. Default is least fit culling.
*
* @param cullOperator An instance of a CullOperator
*/
public Builder cullOperator(CullOperator cullOperator) {
this.cullOperator = cullOperator;
return this;
}
public PopulationModel build() {
if (cullOperator == null) {
cullOperator = new LeastFitCullOperator();
}
if (populationInitializer == null) {
populationInitializer = new EmptyPopulationInitializer();
}
return new PopulationModel(this);
}
}
public PopulationModel(PopulationModel.Builder builder) {
populationSize = builder.populationSize;
population = new ArrayList<>(builder.populationSize);
PopulationInitializer populationInitializer = builder.populationInitializer;
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
population.clear();
population.addAll(initializedPopulation);
cullOperator = builder.cullOperator;
cullOperator.initializeInstance(this);
}
/**
* Called by the GeneticSearchCandidateGenerator
*/
public void initializeInstance(boolean minimizeScore) {
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
}
/**
* Add a PopulationListener to the list of change listeners
* @param listener A PopulationListener instance
*/
public void addListener(PopulationListener listener) {
populationListeners.add(listener);
}
/**
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
*
* @param element The chromosome to be added
*/
public void add(Chromosome element) {
if (population.size() == populationSize) {
cullOperator.cullPopulation();
}
population.add(element);
Collections.sort(population, chromosomeComparator);
triggerPopulationChangedListeners(population);
}
/**
* @return Return false when the population is below the culled size, otherwise true. <br>
* Used by the selection operator to know if the population is still too small and should generate random genes.
*/
public boolean isReadyToBreed() {
return population.size() >= cullOperator.getCulledSize();
}
private void triggerPopulationChangedListeners(List<Chromosome> population) {
for (PopulationListener listener : populationListeners) {
listener.onChanged(population);
}
}
}

View File

@ -1,199 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import java.util.Arrays;
/**
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
* will start to generate offsprings of parents selected in the population.
*
* @author Alexandre Boulanger
*/
public class GeneticSelectionOperator extends SelectionOperator {
private final static int PREVIOUS_GENES_TO_KEEP = 100;
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
private final CrossoverOperator crossoverOperator;
private final MutationOperator mutationOperator;
private final RandomGenerator rng;
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
private int previousGenesIdx = 0;
public static class Builder {
private ChromosomeFactory chromosomeFactory;
private PopulationModel populationModel;
private CrossoverOperator crossoverOperator;
private MutationOperator mutationOperator;
private RandomGenerator rng;
/**
* Use an alternate crossover behavior. Default is SinglePointCrossover.
*
* @param crossoverOperator An instance of CrossoverOperator
*/
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
this.crossoverOperator = crossoverOperator;
return this;
}
/**
* Use an alternate mutation behavior. Default is RandomMutationOperator.
*
* @param mutationOperator An instance of MutationOperator
*/
public Builder mutationOperator(MutationOperator mutationOperator) {
this.mutationOperator = mutationOperator;
return this;
}
/**
* Use a supplied RandomGenerator
*
* @param rng An instance of RandomGenerator
*/
public Builder randomGenerator(RandomGenerator rng) {
this.rng = rng;
return this;
}
public GeneticSelectionOperator build() {
if (crossoverOperator == null) {
crossoverOperator = new SinglePointCrossover.Builder().build();
}
if (mutationOperator == null) {
mutationOperator = new RandomMutationOperator.Builder().build();
}
if (rng == null) {
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
}
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
}
}
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
RandomGenerator rng) {
this.crossoverOperator = crossoverOperator;
this.mutationOperator = mutationOperator;
this.rng = rng;
}
/**
* Called by GeneticSearchCandidateGenerator
*/
@Override
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
super.initializeInstance(populationModel, chromosomeFactory);
crossoverOperator.initializeInstance(populationModel);
}
/**
* Build a new set of genes. Has two distinct modes of operation
* <ul>
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
* </ul>
* @return Returns the generated set of genes
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
* or if the crossover and the mutation operators can't generate a set,
* this exception is thrown.
*/
@Override
public double[] buildNextGenes() {
double[] result;
boolean hasAlreadyBeenTried;
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
do {
if (populationModel.isReadyToBreed()) {
result = buildOffspring();
} else {
result = buildRandomGenes();
}
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
}
} while (hasAlreadyBeenTried);
previousGenes[previousGenesIdx] = result;
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
return result;
}
private boolean hasAlreadyBeenTried(double[] genes) {
for (int i = 0; i < previousGenes.length; ++i) {
double[] current = previousGenes[i];
if (current != null && Arrays.equals(current, genes)) {
return true;
}
}
return false;
}
private double[] buildOffspring() {
double[] offspringValues;
boolean isModified;
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
do {
CrossoverResult crossoverResult = crossoverOperator.crossover();
offspringValues = crossoverResult.getGenes();
isModified = crossoverResult.isModified();
isModified |= mutationOperator.mutate(offspringValues);
if (!isModified && --attemptsRemaining == 0) {
throw new GeneticGenerationException(
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
MAX_NUM_GENERATION_ATTEMPTS));
}
} while (!isModified);
return offspringValues;
}
private double[] buildRandomGenes() {
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
for (int i = 0; i < randomValues.length; ++i) {
randomValues[i] = rng.nextDouble();
}
return randomValues;
}
}

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
/**
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
*
* @author Alexandre Boulanger
*/
public abstract class SelectionOperator {
protected PopulationModel populationModel;
protected ChromosomeFactory chromosomeFactory;
/**
* Called by GeneticSearchCandidateGenerator
*/
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
this.populationModel = populationModel;
this.chromosomeFactory = chromosomeFactory;
}
/**
* Generate a new set of genes.
*/
public abstract double[] buildNextGenes();
}

View File

@ -1,48 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.generator.util;
import org.nd4j.common.function.Supplier;
import java.io.*;
public class SerializedSupplier<T> implements Serializable, Supplier<T> {
private byte[] asBytes;
public SerializedSupplier(T obj){
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(obj);
oos.flush();
oos.close();
asBytes = baos.toByteArray();
} catch (Exception e){
throw new RuntimeException("Error serializing object - must be serializable",e);
}
}
@Override
public T get() {
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))){
return (T)ois.readObject();
} catch (Exception e){
throw new RuntimeException("Error deserializing object",e);
}
}
}

View File

@ -1,78 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* BooleanParameterSpace is a {@code ParameterSpace<Boolean>}; Defines {True, False} as a parameter space
* If argument to setValue is less than or equal to 0.5 it will return True else False
*
* @author susaneraly
*/
@EqualsAndHashCode
public class BooleanSpace implements ParameterSpace<Boolean> {
private int index = -1;
@Override
public Boolean getValue(double[] input) {
if (index == -1) {
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
}
if (input[index] <= 0.5) return Boolean.TRUE;
else return Boolean.FALSE;
}
@Override
public int numParameters() {
return 1;
}
@Override
public List<ParameterSpace> collectLeaves() {
return Collections.singletonList((ParameterSpace) this);
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.emptyMap();
}
@Override
public boolean isLeaf() {
return true;
}
@Override
public void setIndices(int... indices) {
if (indices == null || indices.length != 1)
throw new IllegalArgumentException("Invalid index");
this.index = indices[0];
}
@Override
public String toString() {
return "BooleanSpace()";
}
}

View File

@ -1,92 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* FixedValue is a ParameterSpace that defines only a single fixed value
*
* @param <T> Type of (fixed) value
*/
@EqualsAndHashCode
@JsonSerialize(using = FixedValueSerializer.class)
@JsonDeserialize(using = FixedValueDeserializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class FixedValue<T> implements ParameterSpace<T> {
@Getter
private Object value;
private int index;
@JsonCreator
public FixedValue(@JsonProperty("value") T value) {
this.value = value;
}
@Override
public String toString() {
return "FixedValue(" + ObjectUtils.valueToString(value) + ")";
}
@Override
public T getValue(double[] input) {
return (T) value;
}
@Override
public int numParameters() {
return 0;
}
@Override
public List<ParameterSpace> collectLeaves() {
return Collections.emptyList();
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.emptyMap();
}
@Override
public boolean isLeaf() {
return true;
}
@Override
public void setIndices(int... indices) {
if (indices != null && indices.length != 0)
throw new IllegalArgumentException(
"Invalid call: FixedValue ParameterSpace " + "should not be given an index (0 params)");
}
}

View File

@ -1,137 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.continuous;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* ContinuousParametSpace is a {@code ParameterSpace<Double>} that (optionally) takes an Apache Commons
* {@link RealDistribution} when used for random sampling (such as in a RandomSearchCandidateGenerator)
*
* @author Alex Black
*/
public class ContinuousParameterSpace implements ParameterSpace<Double> {
//Need to use custom serializers/deserializers for commons RealDistribution instances
@JsonSerialize(using = RealDistributionSerializer.class)
@JsonDeserialize(using = RealDistributionDeserializer.class)
private RealDistribution distribution;
private int index = -1;
/**
* ContinuousParameterSpace with uniform distribution between the minimum and maximum values
*
* @param min Minimum value that can be generated
* @param max Maximum value that can be generated
*/
public ContinuousParameterSpace(double min, double max) {
this(new UniformRealDistribution(min, max));
}
/**
* ConditiousParameterSpcae wiht a specified probability distribution. The provided distribution defines the min/max
* values, and (for random search, etc) will be used when generating random values
*
* @param distribution Distribution to sample from
*/
public ContinuousParameterSpace(@JsonProperty("distribution") RealDistribution distribution) {
this.distribution = distribution;
}
@Override
public Double getValue(double[] input) {
if (index == -1) {
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
}
return distribution.inverseCumulativeProbability(input[index]);
}
@Override
public int numParameters() {
return 1;
}
@Override
public List<ParameterSpace> collectLeaves() {
return Collections.singletonList((ParameterSpace) this);
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.emptyMap();
}
@Override
public boolean isLeaf() {
return true;
}
@Override
public void setIndices(int... indices) {
if (indices == null || indices.length != 1) {
throw new IllegalArgumentException("Invalid index");
}
this.index = indices[0];
}
@Override
public String toString() {
if (distribution instanceof UniformRealDistribution) {
return "ContinuousParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
+ distribution.getSupportUpperBound() + ")";
} else {
return "ContinuousParameterSpace(" + distribution + ")";
}
}
public boolean equals(Object o) {
if (o == this)
return true;
if (!(o instanceof ContinuousParameterSpace))
return false;
final ContinuousParameterSpace other = (ContinuousParameterSpace) o;
if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
return false;
return this.index == other.index;
}
public int hashCode() {
final int PRIME = 59;
int result = 1;
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
result = result * PRIME + this.index;
return result;
}
}

View File

@ -1,114 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.discrete;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.util.*;
/**
* A DiscreteParameterSpace is used for a set of un-ordered values
*
* @param <P> Parameter type
* @author Alex Black
*/
@EqualsAndHashCode
public class DiscreteParameterSpace<P> implements ParameterSpace<P> {
@JsonSerialize
private List<P> values;
private int index = -1;
public DiscreteParameterSpace(@JsonProperty("values") P... values) {
if (values != null)
this.values = Arrays.asList(values);
}
public DiscreteParameterSpace(Collection<P> values) {
this.values = new ArrayList<>(values);
}
public int numValues() {
return values.size();
}
@Override
public P getValue(double[] input) {
if (index == -1) {
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
}
if (values == null)
throw new IllegalStateException("Values are null.");
//Map a value in range [0,1] to one of the list of values
//First value: [0,width], second: (width,2*width], third: (3*width,4*width] etc
int size = values.size();
if (size == 1)
return values.get(0);
double width = 1.0 / size;
int val = (int) (input[index] / width);
return values.get(Math.min(val, size - 1));
}
@Override
public int numParameters() {
return 1;
}
@Override
public List<ParameterSpace> collectLeaves() {
return Collections.singletonList((ParameterSpace) this);
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.emptyMap();
}
@Override
public boolean isLeaf() {
return true;
}
@Override
public void setIndices(int... indices) {
if (indices == null || indices.length != 1) {
throw new IllegalArgumentException("Invalid index");
}
this.index = indices[0];
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("DiscreteParameterSpace(");
int n = values.size();
for (int i = 0; i < n; i++) {
P value = values.get(i);
sb.append(ObjectUtils.valueToString(value));
sb.append((i == n - 1 ? ")" : ","));
}
return sb.toString();
}
}

View File

@ -1,151 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.integer;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionSerializer;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* IntegerParameterSpace is a {@code ParameterSpace<Integer>}; i.e., defines an ordered space of integers between
* some minimum and maximum value
*
* @author Alex Black
*/
@JsonIgnoreProperties({"min", "max"})
@NoArgsConstructor
public class IntegerParameterSpace implements ParameterSpace<Integer> {
@JsonSerialize(using = IntegerDistributionSerializer.class)
@JsonDeserialize(using = IntegerDistributionDeserializer.class)
private IntegerDistribution distribution;
private int index = -1;
/**
* Create an IntegerParameterSpace with a uniform distribution between the specified min/max (inclusive)
*
* @param min Min value, inclusive
* @param max Max value, inclusive
*/
public IntegerParameterSpace(int min, int max) {
this(new UniformIntegerDistribution(min, max));
}
/**
* Crate an IntegerParametSpace from the given IntegerDistribution
*
* @param distribution Distribution to use
*/
@JsonCreator
public IntegerParameterSpace(@JsonProperty("distribution") IntegerDistribution distribution) {
this.distribution = distribution;
}
public int getMin() {
return distribution.getSupportLowerBound();
}
public int getMax() {
return distribution.getSupportUpperBound();
}
@Override
public Integer getValue(double[] input) {
if (index == -1) {
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
}
return distribution.inverseCumulativeProbability(input[index]);
}
@Override
public int numParameters() {
return 1;
}
@Override
public List<ParameterSpace> collectLeaves() {
return Collections.singletonList((ParameterSpace) this);
}
@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return Collections.emptyMap();
}
@Override
public boolean isLeaf() {
return true;
}
@Override
public void setIndices(int... indices) {
if (indices == null || indices.length != 1)
throw new IllegalArgumentException("Invalid index");
this.index = indices[0];
}
@Override
public String toString() {
if (distribution instanceof UniformIntegerDistribution) {
return "IntegerParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
+ distribution.getSupportUpperBound() + ")";
} else {
return "IntegerParameterSpace(" + distribution + ")";
}
}
public boolean equals(Object o) {
if (o == this)
return true;
if (!(o instanceof IntegerParameterSpace))
return false;
final IntegerParameterSpace other = (IntegerParameterSpace) o;
if (!other.canEqual(this))
return false;
if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionEquals(distribution, other.distribution))
return false;
return this.index == other.index;
}
public int hashCode() {
final int PRIME = 59;
int result = 1;
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
result = result * PRIME + this.index;
return result;
}
protected boolean canEqual(Object other) {
return other instanceof IntegerParameterSpace;
}
}

View File

@ -1,71 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.math;
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.List;
/**
* A simple parameter space that implements scalar mathematical operations on another parameter space. This allows you
* to do things like Y = X * 2, where X is a parameter space. For example, a layer size hyperparameter could be set
* using this to 2x the size of the previous layer
*
* @param <T> Type of the parameter space
* @author Alex Black
*/
public class MathOp<T extends Number> extends AbstractParameterSpace<T> {
private ParameterSpace<T> parameterSpace;
private Op op;
private T scalar;
public MathOp(ParameterSpace<T> parameterSpace, Op op, T scalar){
this.parameterSpace = parameterSpace;
this.op = op;
this.scalar = scalar;
}
@Override
public T getValue(double[] parameterValues) {
T u = parameterSpace.getValue(parameterValues);
return op.doOp(u, scalar);
}
@Override
public int numParameters() {
return parameterSpace.numParameters();
}
@Override
public List<ParameterSpace> collectLeaves() {
return parameterSpace.collectLeaves();
}
@Override
public boolean isLeaf() {
return false;
}
@Override
public void setIndices(int... indices) {
parameterSpace.setIndices(indices);
}
}

View File

@ -1,78 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.math;
public enum Op {
ADD, SUB, MUL, DIV;
//Package private
<T extends Number> T doOp(T first, T second){
if(first instanceof Integer || first instanceof Long){
long result;
switch (this){
case ADD:
result = Long.valueOf(first.longValue() + second.longValue());
break;
case SUB:
result = Long.valueOf(first.longValue() - second.longValue());
break;
case MUL:
result = Long.valueOf(first.longValue() * second.longValue());
break;
case DIV:
result = Long.valueOf(first.longValue() / second.longValue());
break;
default:
throw new UnsupportedOperationException("Unknown op: " + this);
}
if(first instanceof Long){
return (T)Long.valueOf(result);
} else {
return (T)Integer.valueOf((int)result);
}
} else if(first instanceof Double || first instanceof Float){
double result;
switch (this){
case ADD:
result = Double.valueOf(first.doubleValue() + second.doubleValue());
break;
case SUB:
result = Double.valueOf(first.doubleValue() - second.doubleValue());
break;
case MUL:
result = Double.valueOf(first.doubleValue() * second.doubleValue());
break;
case DIV:
result = Double.valueOf(first.doubleValue() / second.doubleValue());
break;
default:
throw new UnsupportedOperationException("Unknown op: " + this);
}
if(first instanceof Double){
return (T)Double.valueOf(result);
} else {
return (T)Float.valueOf((float)result);
}
} else {
throw new UnsupportedOperationException("Not supported type: only Integer, Long, Double, Float supported" +
" here. Got type: " + first.getClass());
}
}
}

View File

@ -1,81 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.parameter.math;
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* A simple parameter space that implements pairwise mathematical operations on another parameter space. This allows you
* to do things like Z = X + Y, where X and Y are parameter spaces.
*
* @param <T> Type of the parameter space
* @author Alex Black
*/
public class PairMathOp<T extends Number> extends AbstractParameterSpace<T> {
private ParameterSpace<T> first;
private ParameterSpace<T> second;
private Op op;
public PairMathOp(ParameterSpace<T> first, ParameterSpace<T> second, Op op){
this.first = first;
this.second = second;
this.op = op;
}
@Override
public T getValue(double[] parameterValues) {
T f = first.getValue(parameterValues);
T s = second.getValue(parameterValues);
return op.doOp(f, s);
}
@Override
public int numParameters() {
return first.numParameters() + second.numParameters();
}
@Override
public List<ParameterSpace> collectLeaves() {
List<ParameterSpace> l = new ArrayList<>();
l.addAll(first.collectLeaves());
l.addAll(second.collectLeaves());
return l;
}
@Override
public boolean isLeaf() {
return false;
}
@Override
public void setIndices(int... indices) {
int n1 = first.numParameters();
int n2 = second.numParameters();
int[] s1 = Arrays.copyOfRange(indices, 0, n1);
int[] s2 = Arrays.copyOfRange(indices, n1, n1+n2);
first.setIndices(s1);
second.setIndices(s2);
}
}

View File

@ -1,381 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner;
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* BaseOptimization runner: responsible for scheduling tasks, saving results using the result saver, etc.
*
* @author Alex Black
*/
@Slf4j
public abstract class BaseOptimizationRunner implements IOptimizationRunner {
private static final int POLLING_FREQUENCY = 1;
private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS;
protected OptimizationConfiguration config;
protected Queue<Future<OptimizationResult>> queuedFutures = new ConcurrentLinkedQueue<>();
protected BlockingQueue<Future<OptimizationResult>> completedFutures = new LinkedBlockingQueue<>();
protected AtomicInteger totalCandidateCount = new AtomicInteger();
protected AtomicInteger numCandidatesCompleted = new AtomicInteger();
protected AtomicInteger numCandidatesFailed = new AtomicInteger();
protected Double bestScore = null;
protected Long bestScoreTime = null;
protected AtomicInteger bestScoreCandidateIndex = new AtomicInteger(-1);
protected List<ResultReference> allResults = new ArrayList<>();
protected Map<Integer, CandidateInfo> currentStatus = new ConcurrentHashMap<>(); //TODO: better design possible?
protected ExecutorService futureListenerExecutor;
protected List<StatusListener> statusListeners = new ArrayList<>();
protected BaseOptimizationRunner(OptimizationConfiguration config) {
this.config = config;
if (config.getTerminationConditions() == null || config.getTerminationConditions().size() == 0) {
throw new IllegalArgumentException("Cannot create BaseOptimizationRunner without TerminationConditions ("
+ "termination conditions are null or empty)");
}
}
protected void init() {
futureListenerExecutor = Executors.newFixedThreadPool(maxConcurrentTasks(), new ThreadFactory() {
private AtomicLong counter = new AtomicLong(0);
@Override
public Thread newThread(Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
t.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement());
return t;
}
});
}
/**
*
*/
@Override
public void execute() {
log.info("{}: execution started", this.getClass().getSimpleName());
config.setExecutionStartTime(System.currentTimeMillis());
for (StatusListener listener : statusListeners) {
listener.onInitialization(this);
}
//Initialize termination conditions (start timers, etc)
for (TerminationCondition c : config.getTerminationConditions()) {
c.initialize(this);
}
//Queue initial tasks:
List<Future<OptimizationResult>> tempList = new ArrayList<>(100);
while (true) {
//Otherwise: add tasks if required
Future<OptimizationResult> future = null;
try {
future = completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT);
} catch (InterruptedException e) {
//No op?
}
if (future != null) {
tempList.add(future);
}
completedFutures.drainTo(tempList);
//Process results (if any)
for (Future<OptimizationResult> f : tempList) {
queuedFutures.remove(f);
processReturnedTask(f);
}
if (tempList.size() > 0) {
for (StatusListener sl : statusListeners) {
sl.onRunnerStatusChange(this);
}
}
tempList.clear();
//Check termination conditions:
if (terminate()) {
shutdown(true);
break;
}
//Add additional tasks
while (config.getCandidateGenerator().hasMoreCandidates() && queuedFutures.size() < maxConcurrentTasks()) {
Candidate candidate = config.getCandidateGenerator().getCandidate();
CandidateInfo status;
if (candidate.getException() != null) {
//Failed on generation...
status = processFailedCandidates(candidate);
} else {
long created = System.currentTimeMillis();
ListenableFuture<OptimizationResult> f;
if(config.getDataSource() != null){
f = execute(candidate, config.getDataSource(), config.getDataSourceProperties(), config.getScoreFunction());
} else {
f = execute(candidate, config.getDataProvider(), config.getScoreFunction());
}
f.addListener(new OnCompletionListener(f), futureListenerExecutor);
queuedFutures.add(f);
totalCandidateCount.getAndIncrement();
status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Created, null,
created, null, null, candidate.getFlatParameters(), null);
currentStatus.put(candidate.getIndex(), status);
}
for (StatusListener listener : statusListeners) {
listener.onCandidateStatusChange(status, this, null);
}
}
}
//Process any final (completed) tasks:
completedFutures.drainTo(tempList);
for (Future<OptimizationResult> f : tempList) {
queuedFutures.remove(f);
processReturnedTask(f);
}
tempList.clear();
log.info("Optimization runner: execution complete");
for (StatusListener listener : statusListeners) {
listener.onShutdown(this);
}
}
private CandidateInfo processFailedCandidates(Candidate<?> candidate) {
//In case the candidate fails during the creation of the candidate
long time = System.currentTimeMillis();
String stackTrace = ExceptionUtils.getStackTrace(candidate.getException());
CandidateInfo newStatus = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, time, time,
time, candidate.getFlatParameters(), stackTrace);
currentStatus.put(candidate.getIndex(), newStatus);
return newStatus;
}
/**
* Process returned task (either completed or failed
*/
private void processReturnedTask(Future<OptimizationResult> future) {
long currentTime = System.currentTimeMillis();
OptimizationResult result;
try {
result = future.get(100, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new RuntimeException("Unexpected InterruptedException thrown for task", e);
} catch (ExecutionException e) {
//Note that most of the time, an OptimizationResult is returned even for an exception
//This is just to handle any that are missed there (or, by implementations that don't properly do this)
log.warn("Task failed", e);
numCandidatesFailed.getAndIncrement();
return;
} catch (TimeoutException e) {
throw new RuntimeException(e); //TODO
}
//Update internal status:
CandidateInfo status = currentStatus.get(result.getIndex());
CandidateInfo newStatus = new CandidateInfo(result.getIndex(), result.getCandidateInfo().getCandidateStatus(),
result.getScore(), status.getCreatedTime(), result.getCandidateInfo().getStartTime(),
currentTime, status.getFlatParams(), result.getCandidateInfo().getExceptionStackTrace());
currentStatus.put(result.getIndex(), newStatus);
//Listeners (on complete, etc) should be executed in underlying task
if (result.getCandidateInfo().getCandidateStatus() == CandidateStatus.Failed) {
log.info("Task {} failed during execution: {}", result.getIndex(), result.getCandidateInfo().getExceptionStackTrace());
numCandidatesFailed.getAndIncrement();
} else {
//Report completion to candidate generator
config.getCandidateGenerator().reportResults(result);
Double score = result.getScore();
log.info("Completed task {}, score = {}", result.getIndex(), result.getScore());
boolean minimize = config.getScoreFunction().minimize();
if (score != null && (bestScore == null
|| ((minimize && score < bestScore) || (!minimize && score > bestScore)))) {
if (bestScore == null) {
log.info("New best score: {} (first completed model)", score);
} else {
int idx = result.getIndex();
int lastBestIdx = bestScoreCandidateIndex.get();
log.info("New best score: {}, model {} (prev={}, model {})", score, idx, bestScore, lastBestIdx);
}
bestScore = score;
bestScoreTime = System.currentTimeMillis();
bestScoreCandidateIndex.set(result.getIndex());
}
numCandidatesCompleted.getAndIncrement();
//Model saving is done in the optimization tasks, to avoid CUDA threading issues
ResultReference resultReference = result.getResultReference();
if (resultReference != null)
allResults.add(resultReference);
}
}
@Override
public int numCandidatesTotal() {
return totalCandidateCount.get();
}
@Override
public int numCandidatesCompleted() {
return numCandidatesCompleted.get();
}
@Override
public int numCandidatesFailed() {
return numCandidatesFailed.get();
}
@Override
public int numCandidatesQueued() {
return queuedFutures.size();
}
@Override
public Double bestScore() {
return bestScore;
}
@Override
public Long bestScoreTime() {
return bestScoreTime;
}
@Override
public int bestScoreCandidateIndex() {
return bestScoreCandidateIndex.get();
}
@Override
public List<ResultReference> getResults() {
return new ArrayList<>(allResults);
}
@Override
public OptimizationConfiguration getConfiguration() {
return config;
}
@Override
public void addListeners(StatusListener... listeners) {
for (StatusListener l : listeners) {
if (!statusListeners.contains(l)) {
statusListeners.add(l);
}
}
}
@Override
public void removeListeners(StatusListener... listeners) {
for (StatusListener l : listeners) {
statusListeners.remove(l);
}
}
@Override
public void removeAllListeners() {
statusListeners.clear();
}
@Override
public List<CandidateInfo> getCandidateStatus() {
return new ArrayList<>(currentStatus.values());
}
private boolean terminate() {
for (TerminationCondition c : config.getTerminationConditions()) {
if (c.terminate(this)) {
log.info("BaseOptimizationRunner global termination condition hit: {}", c);
return true;
}
}
return false;
}
@AllArgsConstructor
@Data
private class FutureDetails {
private final Future<OptimizationResult> future;
private final long startTime;
private final int index;
}
@AllArgsConstructor
private class OnCompletionListener implements Runnable {
private Future<OptimizationResult> future;
@Override
public void run() {
completedFutures.add(future);
}
}
protected abstract int maxConcurrentTasks();
@Deprecated
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
ScoreFunction scoreFunction);
@Deprecated
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates,
DataProvider dataProvider, ScoreFunction scoreFunction);
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource,
Properties dataSourceProperties, ScoreFunction scoreFunction);
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource,
Properties dataSourceProperties, ScoreFunction scoreFunction);
}

View File

@ -1,43 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner;
import lombok.AllArgsConstructor;
import lombok.Data;
/**
* Simple helper class to store status of a candidate that is/has been/will be executed
*/
@AllArgsConstructor
@Data
public class CandidateInfo {
public CandidateInfo() {
//No arg constructor for Jackson
}
private int index;
private CandidateStatus candidateStatus;
private Double score;
private long createdTime;
private Long startTime;
private Long endTime;
private double[] flatParams; //Same as parameters in Candidate class
private String exceptionStackTrace;
}

View File

@ -1,26 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner;
/**
* Status for candidates
*/
public enum CandidateStatus {
Created, Running, Complete, Failed, Cancelled
}

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.List;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface IOptimizationRunner {
void execute();
/** Total number of candidates: created (scheduled), completed and failed */
int numCandidatesTotal();
int numCandidatesCompleted();
int numCandidatesFailed();
/** Number of candidates running or queued */
int numCandidatesQueued();
/** Best score found so far */
Double bestScore();
/** Time that the best score was found at, or 0 if no jobs have completed successfully */
Long bestScoreTime();
/** Index of the best scoring candidate, or -1 if no candidate has scored yet*/
int bestScoreCandidateIndex();
List<ResultReference> getResults();
OptimizationConfiguration getConfiguration();
void addListeners(StatusListener... listeners);
void removeListeners(StatusListener... listeners);
void removeAllListeners();
List<CandidateInfo> getCandidateStatus();
/**
* @param awaitCompletion If true: await completion of currently scheduled tasks. If false: shutdown immediately,
* cancelling any currently executing tasks
*/
void shutdown(boolean awaitCompletion);
}

View File

@ -1,152 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner;
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
/**
* LocalOptimizationRunner: execute hyperparameter optimization
* locally (on current machine, in current JVM).
*
* @author Alex Black
*/
public class LocalOptimizationRunner extends BaseOptimizationRunner {
public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1;
private final int maxConcurrentTasks;
private TaskCreator taskCreator;
private ListeningExecutorService executor;
@Setter
private long shutdownMaxWaitMS = 2L * 24 * 60 * 60 * 1000;
public LocalOptimizationRunner(OptimizationConfiguration config){
this(config, null);
}
public LocalOptimizationRunner(OptimizationConfiguration config, TaskCreator taskCreator) {
this(DEFAULT_MAX_CONCURRENT_TASKS, config, taskCreator);
}
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config){
this(maxConcurrentTasks, config, null);
}
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config, TaskCreator taskCreator) {
super(config);
if (maxConcurrentTasks <= 0)
throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")");
this.maxConcurrentTasks = maxConcurrentTasks;
if(taskCreator == null){
Class<? extends ParameterSpace> psClass = config.getCandidateGenerator().getParameterSpace().getClass();
taskCreator = TaskCreatorProvider.defaultTaskCreatorFor(psClass);
if(taskCreator == null){
throw new IllegalStateException("No TaskCreator was provided and a default TaskCreator cannot be " +
"inferred for ParameterSpace class " + psClass.getName() + ". Please provide a TaskCreator " +
"via the LocalOptimizationRunner constructor");
}
}
this.taskCreator = taskCreator;
ExecutorService exec = Executors.newFixedThreadPool(maxConcurrentTasks, new ThreadFactory() {
private AtomicLong counter = new AtomicLong(0);
@Override
public Thread newThread(Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
t.setName("LocalCandidateExecutor-" + counter.getAndIncrement());
return t;
}
});
executor = MoreExecutors.listeningDecorator(exec);
init();
}
@Override
protected int maxConcurrentTasks() {
return maxConcurrentTasks;
}
@Override
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
ScoreFunction scoreFunction) {
return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0);
}
@Override
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, DataProvider dataProvider,
ScoreFunction scoreFunction) {
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
for (Candidate candidate : candidates) {
Callable<OptimizationResult> task =
taskCreator.create(candidate, dataProvider, scoreFunction, statusListeners, this);
list.add(executor.submit(task));
}
return list;
}
@Override
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0);
}
@Override
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
for (Candidate candidate : candidates) {
Callable<OptimizationResult> task = taskCreator.create(candidate, dataSource, dataSourceProperties, scoreFunction, statusListeners, this);
list.add(executor.submit(task));
}
return list;
}
@Override
public void shutdown(boolean awaitTermination) {
if(awaitTermination){
try {
executor.shutdown();
executor.awaitTermination(shutdownMaxWaitMS, TimeUnit.MILLISECONDS);
} catch (InterruptedException e){
throw new RuntimeException(e);
}
} else {
executor.shutdownNow();
}
}
}

View File

@ -1,56 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner.listener;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
/**
* BaseStatusListener: implements all methods of {@link StatusListener} as no-op.
* Users can extend this and override only the methods actually required
*
* @author Alex Black
*/
public abstract class BaseStatusListener implements StatusListener{
@Override
public void onInitialization(IOptimizationRunner runner) {
//No op
}
@Override
public void onShutdown(IOptimizationRunner runner) {
//No op
}
@Override
public void onRunnerStatusChange(IOptimizationRunner runner) {
//No op
}
@Override
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) {
//No op
}
@Override
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
//No op
}
}

View File

@ -1,28 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner.listener;
/**
* Created by Alex on 20/07/2017.
*/
public enum StatusChangeType {
CandidateCompleted, CandidateFailed, CandidateNewScheduled, CandidateNewBestScore
}

View File

@ -1,62 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner.listener;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
/**
* The status Listener interface is used to inspect/track the status of execution, both for individual candidates,
* and for the optimisation runner overall.
*
* @author Alex Black
*/
public interface StatusListener {
/** Called when optimization runner starts execution */
void onInitialization(IOptimizationRunner runner);
/** Called when optimization runner terminates */
void onShutdown(IOptimizationRunner runner);
/** Called when any of the summary stats change, for the optimization runner:
* number scheduled, number completed, number failed, best score, etc. */
void onRunnerStatusChange(IOptimizationRunner runner);
/**
* Called when the status of the candidate is change. For example created, completed, failed.
*
* @param candidateInfo Candidate information
* @param runner Optimisation runner calling this method
* @param result Optimisation result. Maybe null.
*/
void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result);
/**
* This method may be called by tasks as they are executing. The intent of this method is to report partial results,
* such as different stages of learning, or scores/evaluations so far
*
* @param candidateInfo Candidate information
* @param candidate Current candidate value/configuration
* @param iteration Current iteration number
*/
void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration);
}

View File

@ -1,59 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.runner.listener.impl;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
/**
* Created by Alex on 20/07/2017.
*/
@Slf4j
public class LoggingStatusListener implements StatusListener {
@Override
public void onInitialization(IOptimizationRunner runner) {
log.info("Optimization runner: initialized");
}
@Override
public void onShutdown(IOptimizationRunner runner) {
log.info("Optimization runner: shut down");
}
@Override
public void onRunnerStatusChange(IOptimizationRunner runner) {
log.info("Optimization runner: status change");
}
@Override
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner,
OptimizationResult result) {
log.info("Candidate status change: {}", candidateInfo);
}
@Override
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
log.info("Candidate iteration #{} - {}", iteration, candidate);
}
}

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.codec.binary.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
/**
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
* @author Alex Black
*/
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
@Override
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = p.getCodec().readTree(p);
String className = node.get("@valueclass").asText();
Class<?> c;
try {
c = Class.forName(className);
} catch (Exception e) {
throw new RuntimeException(e);
}
if(node.has("value")){
//Number, String, Enum
JsonNode valueNode = node.get("value");
Object o = new ObjectMapper().treeToValue(valueNode, c);
return new FixedValue<>(o);
} else {
//Everything else
JsonNode valueNode = node.get("data");
String data = valueNode.asText();
byte[] b = new Base64().decode(data);
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
try {
Object o = ois.readObject();
return new FixedValue<>(o);
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
}
}

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.core.type.WritableTypeId;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
/**
* A custom serializer to handle arbitrary object types
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
* objects very well
*
* @author Alex Black
*/
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
@Override
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
Object o = fixedValue.getValue();
j.writeStringField("@valueclass", o.getClass().getName());
if(o instanceof Number || o instanceof String || o instanceof Enum){
j.writeObjectField("value", o);
} else {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(o);
baos.close();
byte[] b = baos.toByteArray();
String base64 = new Base64().encodeToString(b);
j.writeStringField("data", base64);
}
}
@Override
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
typeSer.writeTypePrefix(gen, typeId);
serialize(value, gen, serializers);
typeSer.writeTypeSuffix(gen, typeId);
}
}

View File

@ -1,61 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.math3.distribution.*;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import java.io.IOException;
/**
* Custom Jackson deserializer for integer distributions
*
* @author Alex Black
*/
public class IntegerDistributionDeserializer extends JsonDeserializer<IntegerDistribution> {
@Override
public IntegerDistribution deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);
String simpleName = node.get("distribution").asText();
switch (simpleName) {
case "BinomialDistribution":
return new BinomialDistribution(node.get("trials").asInt(), node.get("p").asDouble());
case "GeometricDistribution":
return new GeometricDistribution(node.get("p").asDouble());
case "HypergeometricDistribution":
return new HypergeometricDistribution(node.get("populationSize").asInt(),
node.get("numberOfSuccesses").asInt(), node.get("sampleSize").asInt());
case "PascalDistribution":
return new PascalDistribution(node.get("r").asInt(), node.get("p").asDouble());
case "PoissonDistribution":
return new PoissonDistribution(node.get("p").asDouble());
case "UniformIntegerDistribution":
return new UniformIntegerDistribution(node.get("lower").asInt(), node.get("upper").asInt());
case "ZipfDistribution":
return new ZipfDistribution(node.get("numElements").asInt(), node.get("exponent").asDouble());
default:
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
}
}
}

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.math3.distribution.*;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Custom Jackson serializer for integer distributions
*
* @author Alex Black
*/
public class IntegerDistributionSerializer extends JsonSerializer<IntegerDistribution> {
@Override
public void serialize(IntegerDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
throws IOException {
Class<?> c = d.getClass();
String s = c.getSimpleName();
j.writeStartObject();
j.writeStringField("distribution", s);
if (c == BinomialDistribution.class) {
BinomialDistribution bd = (BinomialDistribution) d;
j.writeNumberField("trials", bd.getNumberOfTrials());
j.writeNumberField("p", bd.getProbabilityOfSuccess());
} else if (c == GeometricDistribution.class) {
GeometricDistribution gd = (GeometricDistribution) d;
j.writeNumberField("p", gd.getProbabilityOfSuccess());
} else if (c == HypergeometricDistribution.class) {
HypergeometricDistribution hd = (HypergeometricDistribution) d;
j.writeNumberField("populationSize", hd.getPopulationSize());
j.writeNumberField("numberOfSuccesses", hd.getNumberOfSuccesses());
j.writeNumberField("sampleSize", hd.getSampleSize());
} else if (c == PascalDistribution.class) {
PascalDistribution pd = (PascalDistribution) d;
j.writeNumberField("r", pd.getNumberOfSuccesses());
j.writeNumberField("p", pd.getProbabilityOfSuccess());
} else if (c == PoissonDistribution.class) {
PoissonDistribution pd = (PoissonDistribution) d;
j.writeNumberField("p", pd.getMean());
} else if (c == UniformIntegerDistribution.class) {
UniformIntegerDistribution ud = (UniformIntegerDistribution) d;
j.writeNumberField("lower", ud.getSupportLowerBound());
j.writeNumberField("upper", ud.getSupportUpperBound());
} else if (c == ZipfDistribution.class) {
ZipfDistribution zd = (ZipfDistribution) d;
j.writeNumberField("numElements", zd.getNumberOfElements());
j.writeNumberField("exponent", zd.getExponent());
} else {
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
}
j.writeEndObject();
}
}

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
/**
* Created by Alex on 16/11/2016.
*/
public class JsonMapper {
private static ObjectMapper mapper;
private static ObjectMapper yamlMapper;
static {
mapper = new ObjectMapper();
mapper.registerModule(new JodaModule());
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory());
yamlMapper.registerModule(new JodaModule());
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}
private JsonMapper() {}
/**
* Return the yaml mapper
* @return
*/
public static ObjectMapper getYamlMapper() {
return yamlMapper;
}
/**
* Return a json mapper
* @return
*/
public static ObjectMapper getMapper() {
return mapper;
}
}

View File

@ -1,80 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.math3.distribution.*;
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import java.io.IOException;
/**
* Created by Alex on 14/02/2017.
*/
public class RealDistributionDeserializer extends JsonDeserializer<RealDistribution> {
@Override
public RealDistribution deserialize(JsonParser p, DeserializationContext ctxt)
throws IOException, JsonProcessingException {
JsonNode node = p.getCodec().readTree(p);
String simpleName = node.get("distribution").asText();
switch (simpleName) {
case "BetaDistribution":
return new BetaDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
case "CauchyDistribution":
return new CauchyDistribution(node.get("median").asDouble(), node.get("scale").asDouble());
case "ChiSquaredDistribution":
return new ChiSquaredDistribution(node.get("dof").asDouble());
case "ExponentialDistribution":
return new ExponentialDistribution(node.get("mean").asDouble());
case "FDistribution":
return new FDistribution(node.get("numeratorDof").asDouble(), node.get("denominatorDof").asDouble());
case "GammaDistribution":
return new GammaDistribution(node.get("shape").asDouble(), node.get("scale").asDouble());
case "LevyDistribution":
return new LevyDistribution(node.get("mu").asDouble(), node.get("c").asDouble());
case "LogNormalDistribution":
return new LogNormalDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
case "NormalDistribution":
return new NormalDistribution(node.get("mean").asDouble(), node.get("stdev").asDouble());
case "ParetoDistribution":
return new ParetoDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
case "TDistribution":
return new TDistribution(node.get("dof").asDouble());
case "TriangularDistribution":
return new TriangularDistribution(node.get("a").asDouble(), node.get("b").asDouble(),
node.get("c").asDouble());
case "UniformRealDistribution":
return new UniformRealDistribution(node.get("lower").asDouble(), node.get("upper").asDouble());
case "WeibullDistribution":
return new WeibullDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
case "LogUniformDistribution":
return new LogUniformDistribution(node.get("min").asDouble(), node.get("max").asDouble());
default:
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
}
}
}

View File

@ -1,109 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.math3.distribution.*;
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Custom JSON serializer for Apache commons RealDistribution instances.
* The custom serializer is set up to use the built-in c
*/
public class RealDistributionSerializer extends JsonSerializer<RealDistribution> {
@Override
public void serialize(RealDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
throws IOException {
Class<?> c = d.getClass();
String s = c.getSimpleName();
j.writeStartObject();
j.writeStringField("distribution", s);
if (c == BetaDistribution.class) {
BetaDistribution bd = (BetaDistribution) d;
j.writeNumberField("alpha", bd.getAlpha());
j.writeNumberField("beta", bd.getBeta());
} else if (c == CauchyDistribution.class) {
CauchyDistribution cd = (CauchyDistribution) d;
j.writeNumberField("median", cd.getMedian());
j.writeNumberField("scale", cd.getScale());
} else if (c == ChiSquaredDistribution.class) {
ChiSquaredDistribution cd = (ChiSquaredDistribution) d;
j.writeNumberField("dof", cd.getDegreesOfFreedom());
} else if (c == ExponentialDistribution.class) {
ExponentialDistribution ed = (ExponentialDistribution) d;
j.writeNumberField("mean", ed.getMean());
} else if (c == FDistribution.class) {
FDistribution fd = (FDistribution) d;
j.writeNumberField("numeratorDof", fd.getNumeratorDegreesOfFreedom());
j.writeNumberField("denominatorDof", fd.getDenominatorDegreesOfFreedom());
} else if (c == GammaDistribution.class) {
GammaDistribution gd = (GammaDistribution) d;
j.writeNumberField("shape", gd.getShape());
j.writeNumberField("scale", gd.getScale());
} else if (c == LevyDistribution.class) {
LevyDistribution ld = (LevyDistribution) d;
j.writeNumberField("mu", ld.getLocation());
j.writeNumberField("c", ld.getScale());
} else if (c == LogNormalDistribution.class) {
LogNormalDistribution ln = (LogNormalDistribution) d;
j.writeNumberField("scale", ln.getScale());
j.writeNumberField("shape", ln.getShape());
} else if (c == NormalDistribution.class) {
NormalDistribution nd = (NormalDistribution) d;
j.writeNumberField("mean", nd.getMean());
j.writeNumberField("stdev", nd.getStandardDeviation());
} else if (c == ParetoDistribution.class) {
ParetoDistribution pd = (ParetoDistribution) d;
j.writeNumberField("scale", pd.getScale());
j.writeNumberField("shape", pd.getShape());
} else if (c == TDistribution.class) {
TDistribution td = (TDistribution) d;
j.writeNumberField("dof", td.getDegreesOfFreedom());
} else if (c == TriangularDistribution.class) {
TriangularDistribution td = (TriangularDistribution) d;
j.writeNumberField("a", td.getSupportLowerBound());
j.writeNumberField("b", td.getMode());
j.writeNumberField("c", td.getSupportUpperBound());
} else if (c == UniformRealDistribution.class) {
UniformRealDistribution u = (UniformRealDistribution) d;
j.writeNumberField("lower", u.getSupportLowerBound());
j.writeNumberField("upper", u.getSupportUpperBound());
} else if (c == WeibullDistribution.class) {
WeibullDistribution wb = (WeibullDistribution) d;
j.writeNumberField("alpha", wb.getShape());
j.writeNumberField("beta", wb.getScale());
} else if (c == LogUniformDistribution.class){
LogUniformDistribution lud = (LogUniformDistribution) d;
j.writeNumberField("min", lud.getMin());
j.writeNumberField("max", lud.getMax());
} else {
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + d.getClass());
}
j.writeEndObject();
}
}

View File

@ -1,54 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
/**
* Created by Alex on 16/11/2016.
*/
public class YamlMapper {
private static final ObjectMapper mapper;
static {
mapper = new ObjectMapper(new YAMLFactory());
mapper.registerModule(new JodaModule());
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}
private YamlMapper() {}
public static ObjectMapper getMapper() {
return mapper;
}
}

View File

@ -1,235 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
/**
* Simple utility class used to get access to files at the classpath, or packed into jar.
* Based on Spring ClassPathResource implementation + jar internals access implemented.
*
*
* @author raver119@gmail.com
*/
public class ClassPathResource {
private String resourceName;
private static Logger log = LoggerFactory.getLogger(ClassPathResource.class);
/**
* Builds new ClassPathResource object
*
* @param resourceName String name of resource, to be retrieved
*/
public ClassPathResource(String resourceName) {
if (resourceName == null)
throw new IllegalStateException("Resource name can't be null");
this.resourceName = resourceName;
}
/**
* Returns URL of the requested resource
*
* @return URL of the resource, if it's available in current Jar
*/
private URL getUrl() {
ClassLoader loader = null;
try {
loader = Thread.currentThread().getContextClassLoader();
} catch (Exception e) {
// do nothing
}
if (loader == null) {
loader = ClassPathResource.class.getClassLoader();
}
URL url = loader.getResource(this.resourceName);
if (url == null) {
// try to check for mis-used starting slash
// TODO: see TODO below
if (this.resourceName.startsWith("/")) {
url = loader.getResource(this.resourceName.replaceFirst("[\\\\/]", ""));
if (url != null)
return url;
} else {
// try to add slash, to make clear it's not an issue
// TODO: change this mechanic to actual path purifier
url = loader.getResource("/" + this.resourceName);
if (url != null)
return url;
}
throw new IllegalStateException("Resource '" + this.resourceName + "' cannot be found.");
}
return url;
}
/**
* Returns requested ClassPathResource as File object
*
* Please note: if this method called from compiled jar, temporary file will be created to provide File access
*
* @return File requested at constructor call
* @throws FileNotFoundException
*/
public File getFile() throws FileNotFoundException {
URL url = this.getUrl();
if (isJarURL(url)) {
/*
This is actually request for file, that's packed into jar. Probably the current one, but that doesn't matters.
*/
try {
url = extractActualUrl(url);
File file = File.createTempFile("canova_temp", "file");
file.deleteOnExit();
ZipFile zipFile = new ZipFile(url.getFile());
ZipEntry entry = zipFile.getEntry(this.resourceName);
if (entry == null) {
if (this.resourceName.startsWith("/")) {
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
if (entry == null) {
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
}
} else
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
}
long size = entry.getSize();
InputStream stream = zipFile.getInputStream(entry);
FileOutputStream outputStream = new FileOutputStream(file);
byte[] array = new byte[1024];
int rd = 0;
long bytesRead = 0;
do {
rd = stream.read(array);
outputStream.write(array, 0, rd);
bytesRead += rd;
} while (bytesRead < size);
outputStream.flush();
outputStream.close();
stream.close();
zipFile.close();
return file;
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
/*
It's something in the actual underlying filesystem, so we can just go for it
*/
try {
URI uri = new URI(url.toString().replaceAll(" ", "%20"));
return new File(uri.getSchemeSpecificPart());
} catch (URISyntaxException e) {
return new File(url.getFile());
}
}
}
/**
* Checks, if proposed URL is packed into archive.
*
* @param url URL to be checked
* @return True, if URL is archive entry, False otherwise
*/
private boolean isJarURL(URL url) {
String protocol = url.getProtocol();
return "jar".equals(protocol) || "zip".equals(protocol) || "wsjar".equals(protocol)
|| "code-source".equals(protocol) && url.getPath().contains("!/");
}
/**
* Extracts parent Jar URL from original ClassPath entry URL.
*
* @param jarUrl Original URL of the resource
* @return URL of the Jar file, containing requested resource
* @throws MalformedURLException
*/
private URL extractActualUrl(URL jarUrl) throws MalformedURLException {
String urlFile = jarUrl.getFile();
int separatorIndex = urlFile.indexOf("!/");
if (separatorIndex != -1) {
String jarFile = urlFile.substring(0, separatorIndex);
try {
return new URL(jarFile);
} catch (MalformedURLException var5) {
if (!jarFile.startsWith("/")) {
jarFile = "/" + jarFile;
}
return new URL("file:" + jarFile);
}
} else {
return jarUrl;
}
}
/**
* Returns requested ClassPathResource as InputStream object
*
* @return File requested at constructor call
* @throws FileNotFoundException
*/
public InputStream getInputStream() throws FileNotFoundException {
URL url = this.getUrl();
if (isJarURL(url)) {
try {
url = extractActualUrl(url);
ZipFile zipFile = new ZipFile(url.getFile());
ZipEntry entry = zipFile.getEntry(this.resourceName);
if (entry == null) {
if (this.resourceName.startsWith("/")) {
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
if (entry == null) {
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
}
} else
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
}
return zipFile.getInputStream(entry);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
File srcFile = this.getFile();
return new FileInputStream(srcFile);
}
}
}

View File

@ -1,51 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
public class CollectionUtils {
/**
* Count the number of unique values in a collection
*/
public static int countUnique(Collection<?> collection) {
HashSet<Object> set = new HashSet<>(collection);
return set.size();
}
/**
* Returns a list containing only unique values in a collection
*/
public static <T> List<T> getUnique(Collection<T> collection) {
HashSet<T> set = new HashSet<>();
List<T> out = new ArrayList<>();
for (T t : collection) {
if (!set.contains(t)) {
out.add(t);
set.add(t);
}
}
return out;
}
}

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.util;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import java.util.ArrayList;
import java.util.List;
/**
* Created by Alex on 29/06/2017.
*/
public class LeafUtils {
private LeafUtils() {}
/**
* Returns a list of unique objects, not using the .equals() method, but rather using ==
*
* @param allLeaves Leaf values to process
* @return A list of unique parameter space values
*/
public static List<ParameterSpace> getUniqueObjects(List<ParameterSpace> allLeaves) {
List<ParameterSpace> unique = new ArrayList<>();
for (ParameterSpace p : allLeaves) {
//This isn't especially efficient, but small number of parameters in general means it's fine
boolean found = false;
for (ParameterSpace q : unique) {
if (p == q) {
found = true;
break;
}
}
if (!found) {
unique.add(p);
}
}
return unique;
}
/**
* Count the number of unique parameters in the specified leaf nodes
*
* @param allLeaves Leaf values to count the parameters fore
* @return Number of parameters for all unique objects
*/
public static int countUniqueParameters(List<ParameterSpace> allLeaves) {
List<ParameterSpace> unique = getUniqueObjects(allLeaves);
int count = 0;
for (ParameterSpace ps : unique) {
if (!ps.isLeaf()) {
throw new IllegalStateException("Method should only be used with leaf nodes");
}
count += ps.numParameters();
}
return count;
}
}

View File

@ -1,63 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.util;
import java.util.Arrays;
/**
* @author Alex Black
*/
public class ObjectUtils {
private ObjectUtils() {}
/**
* Get the string representation of the object. Arrays, including primitive arrays, are printed using
* Arrays.toString(...) methods.
*
* @param v Value to convert to a string
* @return String representation
*/
public static String valueToString(Object v) {
if (v.getClass().isArray()) {
if (v.getClass().getComponentType().isPrimitive()) {
Class<?> c = v.getClass().getComponentType();
if (c == int.class) {
return Arrays.toString((int[]) v);
} else if (c == double.class) {
return Arrays.toString((double[]) v);
} else if (c == float.class) {
return Arrays.toString((float[]) v);
} else if (c == long.class) {
return Arrays.toString((long[]) v);
} else if (c == byte.class) {
return Arrays.toString((byte[]) v);
} else if (c == short.class) {
return Arrays.toString((short[]) v);
} else {
return v.toString();
}
} else {
return Arrays.toString((Object[]) v);
}
} else {
return v.toString();
}
}
}

View File

@ -1,51 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.optimize";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -1,158 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.Callable;
public class BraninFunction {
public static class BraninSpace extends AbstractParameterSpace<BraninConfig> {
private int[] indices;
private ParameterSpace<Double> first = new ContinuousParameterSpace(-5, 10);
private ParameterSpace<Double> second = new ContinuousParameterSpace(0, 15);
@Override
public BraninConfig getValue(double[] parameterValues) {
double f = first.getValue(parameterValues);
double s = second.getValue(parameterValues);
return new BraninConfig(f, s); //-5 to +10 and 0 to 15
}
@Override
public int numParameters() {
return 2;
}
@Override
public List<ParameterSpace> collectLeaves() {
List<ParameterSpace> list = new ArrayList<>();
list.addAll(first.collectLeaves());
list.addAll(second.collectLeaves());
return list;
}
@Override
public boolean isLeaf() {
return false;
}
@Override
public void setIndices(int... indices) {
throw new UnsupportedOperationException();
}
}
@AllArgsConstructor
@Data
public static class BraninConfig implements Serializable {
private double x1;
private double x2;
}
public static class BraninScoreFunction implements ScoreFunction {
private static final double a = 1.0;
private static final double b = 5.1 / (4.0 * Math.PI * Math.PI);
private static final double c = 5.0 / Math.PI;
private static final double r = 6.0;
private static final double s = 10.0;
private static final double t = 1.0 / (8.0 * Math.PI);
@Override
public double score(Object m, DataProvider data, Map<String, Object> dataParameters) {
BraninConfig model = (BraninConfig) m;
double x1 = model.getX1();
double x2 = model.getX2();
return a * Math.pow(x2 - b * x1 * x1 + c * x1 - r, 2.0) + s * (1 - t) * Math.cos(x1) + s;
}
@Override
public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
throw new UnsupportedOperationException();
}
@Override
public boolean minimize() {
return true;
}
@Override
public List<Class<?>> getSupportedModelTypes() {
return Collections.<Class<?>>singletonList(BraninConfig.class);
}
@Override
public List<Class<?>> getSupportedDataTypes() {
return Collections.<Class<?>>singletonList(Object.class);
}
}
public static class BraninTaskCreator implements TaskCreator {
@Override
public Callable<OptimizationResult> create(final Candidate c, DataProvider dataProvider,
final ScoreFunction scoreFunction, final List<StatusListener> statusListeners,
IOptimizationRunner runner) {
return new Callable<OptimizationResult>() {
@Override
public OptimizationResult call() throws Exception {
BraninConfig candidate = (BraninConfig) c.getValue();
double score = scoreFunction.score(candidate, null, (Map) null);
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
Thread.sleep(20);
if (statusListeners != null) {
for (StatusListener sl : statusListeners) {
sl.onCandidateIteration(null, null, 0);
}
}
CandidateInfo ci = new CandidateInfo(-1, CandidateStatus.Complete, score,
System.currentTimeMillis(), null, null, null, null);
return new OptimizationResult(c, score, c.getIndex(), null, ci, null);
}
};
}
@Override
public Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource,
Properties dataSourceProperties, ScoreFunction scoreFunction,
List<StatusListener> statusListeners, IOptimizationRunner runner) {
throw new UnsupportedOperationException();
}
}
}

View File

@ -1,120 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
import org.junit.Assert;
import org.junit.Test;
public class TestGeneticSearch extends BaseDL4JTest {
public class TestSelectionOperator extends SelectionOperator {
@Override
public double[] buildNextGenes() {
throw new GeneticGenerationException("Forced exception to test exception handling.");
}
}
public class TestTerminationCondition implements TerminationCondition {
public boolean hasAFailedCandidate = false;
public int evalCount = 0;
@Override
public void initialize(IOptimizationRunner optimizationRunner) {}
@Override
public boolean terminate(IOptimizationRunner optimizationRunner) {
if (++evalCount == 50) {
// Generator did not handle GeneticGenerationException
return true;
}
for (CandidateInfo candidateInfo : optimizationRunner.getCandidateStatus()) {
if (candidateInfo.getCandidateStatus() == CandidateStatus.Failed) {
hasAFailedCandidate = true;
return true;
}
}
return false;
}
}
@Test
public void GeneticSearchCandidateGenerator_getCandidate_ShouldGenerateCandidates() throws Exception {
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
//Define configuration:
CandidateGenerator candidateGenerator =
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
.build();
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
.terminationConditions(new MaxCandidatesCondition(50), testTerminationCondition).build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
runner.addListeners(new LoggingStatusListener());
runner.execute();
Assert.assertFalse(testTerminationCondition.hasAFailedCandidate);
}
@Test
public void GeneticSearchCandidateGenerator_getCandidate_GeneticExceptionShouldMarkCandidateAsFailed() {
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
//Define configuration:
CandidateGenerator candidateGenerator =
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
.selectionOperator(new TestSelectionOperator()).build();
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
.terminationConditions(testTerminationCondition).build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
runner.addListeners(new LoggingStatusListener());
runner.execute();
Assert.assertTrue(testTerminationCondition.hasAFailedCandidate);
}
}

View File

@ -1,106 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.junit.Test;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.*;
public class TestGridSearch extends BaseDL4JTest {
@Test
public void testIndexing() {
int[] nValues = {2, 3};
int prod = 2 * 3;
double[][] expVals = new double[][] {{0.0, 0.0}, {1.0, 0.0}, {0.0, 0.5}, {1.0, 0.5}, {0.0, 1.0}, {1.0, 1.0}};
for (int i = 0; i < prod; i++) {
double[] out = GridSearchCandidateGenerator.indexToValues(nValues, i, prod);
double[] exp = expVals[i];
assertArrayEquals(exp, out, 1e-4);
}
}
@Test
public void testGeneration() throws Exception {
Map<String, Object> commands = new HashMap<>();
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
//Define configuration:
CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
GridSearchCandidateGenerator.Mode.Sequential, commands);
//Check sequential:
double[] expValuesFirst = {-5, 0, 5, 10}; //Range: -5 to +10, with 4 values
double[] expValuesSecond = {0, 5, 10, 15}; //Range: 0 to +15, with 4 values
for (int i = 0; i < 4 * 4; i++) {
BraninFunction.BraninConfig conf = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
double expF = expValuesFirst[i % 4]; //Changes most rapidly
double expS = expValuesSecond[i / 4];
double actF = conf.getX1();
double actS = conf.getX2();
assertEquals(expF, actF, 1e-4);
assertEquals(expS, actS, 1e-4);
}
//Check random order. specifically: check that all values are generated, in some order
double[][] orderedOutput = new double[16][2];
for (int i = 0; i < expValuesFirst.length; i++) {
for (int j = 0; j < expValuesSecond.length; j++) {
orderedOutput[4 * j + i][0] = expValuesFirst[i];
orderedOutput[4 * j + i][1] = expValuesSecond[j];
}
}
candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
GridSearchCandidateGenerator.Mode.RandomOrder, commands);
boolean[] seen = new boolean[16];
int seenCount = 0;
for (int i = 0; i < 4 * 4; i++) {
assertTrue(candidateGenerator.hasMoreCandidates());
BraninFunction.BraninConfig config = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
double x1 = config.getX1();
double x2 = config.getX2();
//Work out which of the values this is...
boolean matched = false;
for (int j = 0; j < 16; j++) {
if (Math.abs(orderedOutput[j][0] - x1) < 1e-5 && Math.abs(orderedOutput[j][1] - x2) < 1e-5) {
matched = true;
if (seen[j])
fail("Same candidate generated multiple times");
seen[j] = true;
seenCount++;
break;
}
}
assertTrue("Candidate " + x1 + ", " + x2 + " not found; invalid?", matched);
}
assertFalse(candidateGenerator.hasMoreCandidates());
assertEquals(16, seenCount);
}
}

View File

@ -1,124 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.junit.Test;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 02/02/2017.
*/
public class TestJson extends BaseDL4JTest {
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);
om.registerModule(new JodaModule());
om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
return om;
}
private static ObjectMapper jsonMapper = getObjectMapper(new JsonFactory());
private static ObjectMapper yamlMapper = getObjectMapper(new YAMLFactory());
@Test
public void testParameterSpaceJson() throws Exception {
List<ParameterSpace<?>> l = new ArrayList<>();
l.add(new FixedValue<>(1.0));
l.add(new FixedValue<>(1));
l.add(new FixedValue<>("string"));
l.add(new ContinuousParameterSpace(-1, 1));
l.add(new ContinuousParameterSpace(new LogNormalDistribution(1, 1)));
l.add(new ContinuousParameterSpace(new NormalDistribution(2, 0.01)));
l.add(new DiscreteParameterSpace<>(1, 5, 7));
l.add(new DiscreteParameterSpace<>("first", "second", "third"));
l.add(new IntegerParameterSpace(0, 10));
l.add(new IntegerParameterSpace(new UniformIntegerDistribution(0, 50)));
l.add(new BooleanSpace());
for (ParameterSpace<?> ps : l) {
String strJson = jsonMapper.writeValueAsString(ps);
String strYaml = yamlMapper.writeValueAsString(ps);
ParameterSpace<?> fromJson = jsonMapper.readValue(strJson, ParameterSpace.class);
ParameterSpace<?> fromYaml = yamlMapper.readValue(strYaml, ParameterSpace.class);
assertEquals(ps, fromJson);
assertEquals(ps, fromYaml);
}
}
@Test
public void testCandidateGeneratorJson() throws Exception {
Map<String, Object> commands = new HashMap<>();
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
List<CandidateGenerator> l = new ArrayList<>();
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
GridSearchCandidateGenerator.Mode.Sequential, commands));
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
GridSearchCandidateGenerator.Mode.RandomOrder, commands));
l.add(new RandomSearchGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), commands));
for (CandidateGenerator cg : l) {
String strJson = jsonMapper.writeValueAsString(cg);
String strYaml = yamlMapper.writeValueAsString(cg);
CandidateGenerator fromJson = jsonMapper.readValue(strJson, CandidateGenerator.class);
CandidateGenerator fromYaml = yamlMapper.readValue(strYaml, CandidateGenerator.class);
assertEquals(cg, fromJson);
assertEquals(cg, fromYaml);
}
}
}

View File

@ -1,63 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
import org.junit.Test;
import java.util.HashMap;
import java.util.Map;
/**
*
* Test random search on the Branin Function:
* http://www.sfu.ca/~ssurjano/branin.html
*/
public class TestRandomSearch extends BaseDL4JTest {
@Test
public void test() throws Exception {
Map<String, Object> commands = new HashMap<>();
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
//Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(new BraninFunction.BraninSpace(), commands);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).scoreFunction(new BraninFunction.BraninScoreFunction())
.terminationConditions(new MaxCandidatesCondition(50)).build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
runner.addListeners(new LoggingStatusListener());
runner.execute();
// System.out.println("----- Complete -----");
}
}

View File

@ -1,72 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestLogUniform extends BaseDL4JTest {
@Test
public void testSimple(){
double min = 0.5;
double max = 3;
double logMin = Math.log(min);
double logMax = Math.log(max);
RealDistribution rd = new LogUniformDistribution(min, max);
for(double d = 0.1; d<= 3.5; d+= 0.1){
double density = rd.density(d);
double cumulative = rd.cumulativeProbability(d);
double dExp;
double cumExp;
if(d < min){
dExp = 0;
cumExp = 0;
} else if( d > max){
dExp = 0;
cumExp = 1;
} else {
dExp = 1.0 / (d * (logMax-logMin));
cumExp = (Math.log(d) - logMin) / (logMax - logMin);
}
assertTrue(dExp >= 0);
assertTrue(cumExp >= 0);
assertTrue(cumExp <= 1.0);
assertEquals(dExp, density, 1e-5);
assertEquals(cumExp, cumulative, 1e-5);
}
rd.reseedRandomGenerator(12345);
for( int i=0; i<100; i++ ){
double d = rd.sample();
assertTrue(d >= min);
assertTrue(d <= max);
}
}
}

View File

@ -1,42 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
public class TestCrossoverOperator extends CrossoverOperator {
private final CrossoverResult[] results;
private int resultIdx = 0;
public PopulationModel getPopulationModel() {
return populationModel;
}
public TestCrossoverOperator(CrossoverResult[] results) {
this.results = results;
}
@Override
public CrossoverResult crossover() {
return results[resultIdx++];
}
}

View File

@ -1,36 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
public class TestMutationOperator implements MutationOperator {
private final boolean[] results;
private int resultIdx = 0;
public TestMutationOperator(boolean[] results) {
this.results = results;
}
@Override
public boolean mutate(double[] genes) {
return results[resultIdx++];
}
}

View File

@ -1,54 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import java.util.List;
public class TestParentSelection extends TwoParentSelection {
public boolean hasBeenInitialized;
private final double[][] parents;
public TestParentSelection(double[][] parents) {
this.parents = parents;
}
public TestParentSelection() {
this(null);
}
@Override
public void initializeInstance(List<Chromosome> population) {
super.initializeInstance(population);
hasBeenInitialized = true;
}
@Override
public double[][] selectParents() {
return parents;
}
public List<Chromosome> getPopulation() {
return population;
}
}

View File

@ -1,32 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import java.util.ArrayList;
import java.util.List;
public class TestPopulationInitializer implements PopulationInitializer {
@Override
public List<Chromosome> getInitializedPopulation(int size) {
return new ArrayList<>();
}
}

View File

@ -1,90 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
public class TestRandomGenerator implements RandomGenerator {
private final int[] intRandomNumbers;
private int currentIntIdx = 0;
private final double[] doubleRandomNumbers;
private int currentDoubleIdx = 0;
public TestRandomGenerator(int[] intRandomNumbers, double[] doubleRandomNumbers) {
this.intRandomNumbers = intRandomNumbers;
this.doubleRandomNumbers = doubleRandomNumbers;
}
@Override
public void setSeed(int i) {
}
@Override
public void setSeed(int[] ints) {
}
@Override
public void setSeed(long l) {
}
@Override
public void nextBytes(byte[] bytes) {
}
@Override
public int nextInt() {
return intRandomNumbers[currentIntIdx++];
}
@Override
public int nextInt(int i) {
return intRandomNumbers[currentIntIdx++];
}
@Override
public long nextLong() {
throw new NotImplementedException("Not implemented");
}
@Override
public boolean nextBoolean() {
throw new NotImplementedException("Not implemented");
}
@Override
public float nextFloat() {
throw new NotImplementedException("Not implemented");
}
@Override
public double nextDouble() {
return doubleRandomNumbers[currentDoubleIdx++];
}
@Override
public double nextGaussian() {
throw new NotImplementedException("Not implemented");
}
}

View File

@ -1,70 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class ArithmeticCrossoverTests extends BaseDL4JTest {
@Test
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
double[][] parents = new double[2][];
parents[0] = new double[] {1.0};
parents[1] = new double[] {2.0};
TestParentSelection parentSelection = new TestParentSelection(parents);
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
ArithmeticCrossover sut =
new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build();
CrossoverResult result = sut.crossover();
Assert.assertFalse(result.isModified());
Assert.assertEquals(1, result.getGenes().length);
Assert.assertEquals(1.0, result.getGenes()[0], 0.001);
}
@Test
public void ArithmeticCrossover_Crossover_WithinCrossoverRate_ShouldReturnLinearCombination() {
double[][] parents = new double[2][];
parents[0] = new double[] {1.0};
parents[1] = new double[] {2.0};
TestParentSelection parentSelection = new TestParentSelection(parents);
RandomGenerator rng = new TestRandomGenerator(null, new double[] {0.1, 0.1});
ArithmeticCrossover sut =
new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build();
CrossoverResult result = sut.crossover();
Assert.assertTrue(result.isModified());
Assert.assertEquals(1, result.getGenes().length);
Assert.assertEquals(1.9, result.getGenes()[0], 0.001);
}
}

View File

@ -1,45 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
public class CrossoverOperatorTests extends BaseDL4JTest {
@Test
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
TestCrossoverOperator sut = new TestCrossoverOperator(null);
PopulationInitializer populationInitializer = new TestPopulationInitializer();
PopulationModel populationModel =
new PopulationModel.Builder().populationInitializer(populationInitializer).build();
sut.initializeInstance(populationModel);
Assert.assertSame(populationModel, sut.getPopulationModel());
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* * Copyright (c) 2021 Deeplearning4j Contributors
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
import java.util.Deque;
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
@Test
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
RandomGenerator rng = new TestRandomGenerator(new int[] {0}, null);
CrossoverPointsGenerator sut = new CrossoverPointsGenerator(10, 2, 2, rng);
Deque<Integer> result = sut.getCrossoverPoints();
Assert.assertEquals(3, result.size());
int a = result.pop();
int b = result.pop();
int c = result.pop();
Assert.assertTrue(a < b);
Assert.assertTrue(b < c);
Assert.assertEquals(Integer.MAX_VALUE, c);
}
}

Some files were not shown because too many files have changed in this diff Show More